diff --git a/README.md b/README.md
index 77eeaca..2647668 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,41 @@
-# cite
-Implementation for our paper "Conditional Image-Text Embedding Networks" (coming soon)
+# Conditional Image-Text Embedding Networks
+
+**cite** contains a Tensorflow implementation for our [paper](https://arxiv.org/abs/1711.08389). If you find this code useful in your research, please consider citing:
+
+    @article{plummerCITE2017,
+      Author = {Bryan A. Plummer and Paige Kordas and M. Hadi Kiapour and Shuai Zheng and Robinson Piramuthu and Svetlana Lazebnik},
+      Title = {Conditional Image-Text Embedding Networks},
+      Journal = {arXiv:1711.08389},
+      Year = {2017}
+    }
+
+This code was tested on an Ubuntu 16.04 system using Tensorflow 1.2.1.
+
+### Phrase Localization Evaluation Demo
+After you download our precomputed features/model, you can test it using:
+
+```Shell
+python main.py --test --spatial --name runs/cite_spatial_k4/model_best
+```
+
+### Training New Models
+Our code contains everything required to train or test models using precomputed features. You can train a new model on Flickr30K Entities using:
+
+```Shell
+python main.py --name <name of experiment>
+```
+
+When training completes, it will output the localization accuracy of the best model on the testing and validation sets. You can see a listing and description of the many tunable parameters with:
+
+```Shell
+python main.py --help
+```
+
+### Precomputed Features
+
+Along with our example data processing script in `data_processing_example`, you can download our precomputed (PASCAL) features for the Flickr30K Entities dataset [here](https://drive.google.com/file/d/1m5DQ3kh2rCkPremgM91chQgJYZxnEbZw/view?usp=sharing) (52G). Unpack the features in a folder named `data` or update the path in the data loader class.
+
+Our best CITE model on Flickr30K Entities using these precomputed features can be found [here](https://drive.google.com/open?id=1rmeIqYTCIduNc2QWUEdXLHFGrlOzz2xO).
+
+
+Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for access to their Similarity Network code that was used as the basis for this implementation.
\ No newline at end of file
diff --git a/data_loader.py b/data_loader.py
new file mode 100644
index 0000000..f7277c4
--- /dev/null
+++ b/data_loader.py
@@ -0,0 +1,123 @@
+import numpy as np
+import h5py
+import os
+
+class DataLoader:
+    """Provides minibatches from data stored on disk in HDF5 format"""
+    def __init__(self, args, region_dim, phrase_dim, plh, split):
+        """Constructor
+
+        Arguments:
+        args -- command line arguments passed into the main function
+        region_dim -- dimensions of the region features
+        phrase_dim -- dimensions of the phrase features
+        plh -- placeholder dictionary containing the tensor inputs
+        split -- the data split (i.e. 'train', 'test', 'val')
+        """
+        datafn = os.path.join('data', 'flickr', '%s_imfeats.h5' % split)
+        self.data = h5py.File(datafn, 'r')
+        vecs = np.array(self.data['phrase_features'], np.float32)
+        uniquePhrases = list(self.data['phrases'])
+        assert(vecs.shape[0] == len(uniquePhrases))
+
+        w2v_dict = {}
+        for index, phrase in enumerate(uniquePhrases):
+            w2v_dict[phrase] = vecs[index, :]
+
+        # mapping from uniquePhrase to w2v
+        self.w2v_dict = w2v_dict
+        self.pairs = list(self.data['pairs'])
+        self.n_pairs = len(self.pairs[0])
+        self.pair_index = list(range(self.n_pairs))
+
+        self.uniquePhrases = uniquePhrases # set of unique phrases
+
+        self.split = split
+        self.plh = plh
+        self.is_train = split == 'train'
+        self.neg_to_pos_ratio = args.neg_to_pos_ratio
+        self.batch_size = args.batch_size
+        self.max_boxes = args.max_boxes
+        if self.is_train:
+            self.success_thresh = args.train_success_thresh
+        else:
+            self.success_thresh = args.test_success_thresh
+
+        self.region_feature_dim = region_dim
+        self.phrase_feature_dim = phrase_dim
+
+    def __len__(self):
+        return self.n_pairs
+
+    def shuffle(self):
+        ''' Shuffles the order of the pairs being sampled
+        '''
+        np.random.shuffle(self.pair_index)
+
+    def num_batches(self):
+        return int(np.ceil(float(len(self)) / self.batch_size))
+
+    def get_batch(self, batch_id):
+        """Returns a minibatch given a valid id for it
+
+        Arguments:
+        batch_id -- number between 0 and self.num_batches()
+
+        Returns:
+        feed_dict -- dictionary containing minibatch data
+        gt_labels -- indicates positive/negative regions
+        num_pairs -- number of pairs without padding
+        """
+        region_features = np.zeros((self.batch_size, self.max_boxes,
+                                    self.region_feature_dim), dtype=np.float32)
+        start_pair = batch_id * self.batch_size
+        end_pair = min(start_pair + self.batch_size, len(self))
+        num_pairs = end_pair - start_pair
+        gt_labels = np.zeros((self.batch_size, self.max_boxes),
+                             dtype=np.float32)
+        phrase_features = np.zeros((self.batch_size, self.phrase_feature_dim),
+                                   dtype=np.float32)
+        for pair_id in range(num_pairs):
+            sample_id = self.pair_index[start_pair + pair_id]
+
+            # paired image
+            im_id = self.pairs[0][sample_id]
+
+            # paired phrase
+            phrase = self.pairs[1][sample_id]
+
+            # phrase instance identifier
+            p_id = self.pairs[2][sample_id]
+
+            # get region features
+            features = np.array(self.data[im_id], np.float32)
+            num_boxes = min(len(features), self.max_boxes)
+            features = features[:num_boxes, :self.region_feature_dim]
+            overlaps = np.array(self.data['%s_%s_%s' % (im_id, phrase, p_id)])
+
+            # last 4 dimensions of overlaps are ground truth box coordinates
+            assert(num_boxes <= len(overlaps) - 4)
+            overlaps = overlaps[:num_boxes]
+            region_features[pair_id, :num_boxes, :] = features
+            phrase_features[pair_id, :] = self.w2v_dict[phrase]
+            gt_labels[pair_id, :num_boxes] = overlaps >= self.success_thresh
+            if self.is_train:
+                num_pos = int(np.sum(gt_labels[pair_id, :]))
+                num_neg = num_pos * self.neg_to_pos_ratio
+                negs = np.random.permutation(np.where(overlaps < 0.3)[0])
+
+                if len(negs) < num_neg: # if not enough negatives
+                    negs = np.random.permutation(np.where(overlaps < 0.4)[0])
+
+                # the logistic loss only counts regions labeled -1 as negatives
+                gt_labels[pair_id, negs[:num_neg]] = -1
+
+        feed_dict = {self.plh['phrase'] : phrase_features,
+                     self.plh['region'] : region_features,
+                     self.plh['train_phase'] : self.is_train,
+                     self.plh['num_boxes'] : self.max_boxes,
+                     self.plh['labels'] : gt_labels
+                     }
+
+        return feed_dict, gt_labels, num_pairs
+
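For reference, the loader above is driven by the placeholders and command-line arguments defined in `main.py`. The snippet below is a minimal, illustrative sketch of that wiring and is not part of the original code; the argument values and the `val` split are assumptions, and it requires the precomputed `data/flickr/<split>_imfeats.h5` files to be present.

```python
import argparse
import tensorflow as tf
from data_loader import DataLoader

# Values mirror the defaults in main.py; 4096 + 5 assumes --spatial on Flickr30K.
args = argparse.Namespace(batch_size=200, max_boxes=200, neg_to_pos_ratio=2,
                          train_success_thresh=0.6, test_success_thresh=0.5)
region_dim, phrase_dim = 4096 + 5, 6000

plh = {'phrase': tf.placeholder(tf.float32, [args.batch_size, phrase_dim]),
       'region': tf.placeholder(tf.float32, [args.batch_size, None, region_dim]),
       'train_phase': tf.placeholder(tf.bool),
       'num_boxes': tf.placeholder(tf.int32),
       'labels': tf.placeholder(tf.float32, [args.batch_size, None])}

loader = DataLoader(args, region_dim, phrase_dim, plh, 'val')
# feed_dict maps the placeholders above to one minibatch of data
feed_dict, gt_labels, num_pairs = loader.get_batch(0)
```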
diff --git a/data_processing_example/README.md b/data_processing_example/README.md
new file mode 100644
index 0000000..48ebe66
--- /dev/null
+++ b/data_processing_example/README.md
@@ -0,0 +1,12 @@
+# Conditional Image-Text Embedding Networks
+
+The code currently assumes datasets are divided into three HDF5 files named `<split>_imfeats.h5`, where `<split>` takes on the values train, test, and val. Each file is assumed to contain the following items:
+
+1. phrase_features: #num_phrase X 6000 dimensional matrix of phrase features
+2. phrases: array of #num_phrase strings corresponding to the phrase features
+3. pairs: 3 x M matrix where each column contains a string representation for the `[image name, phrase, pair identifier]` pairs in the split.
+4. Each `<image name>` should return a #num_boxes X feature dimensional matrix of the visual features. The features should contain the visual representation as well as the spatial features for the box followed by its coordinates (i.e. the precomputed features we released are 4096 (VGG) + 5 (spatial) + 4 (box coordinates) = 4105 dimensional).
+5. Each `<image name>_<phrase>_<pair identifier>` should contain a vector containing the intersection over union with the ground truth box followed by the box's coordinates (i.e. for N boxes the vector should be N + 4 dimensional).
+
+
+The example script uses the [pl-clc](https://github.com/BryanPlummer/pl-clc) repo for parsing and computing features of the Flickr30K Entities dataset. It assumes the built-in MATLAB PCA function is used, and not the one in the `toolbox` external module.
\ No newline at end of file
diff --git a/data_processing_example/process_data.m b/data_processing_example/process_data.m
new file mode 100644
index 0000000..2cb0729
--- /dev/null
+++ b/data_processing_example/process_data.m
@@ -0,0 +1,117 @@
+% Code should be placed in the pl-clc directory to run and assumes
+% the dataset has been downloaded as specified in that repo.
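+%
+% For each <split> the script below writes <split>_imfeats.h5 containing:
+%   phrase_features, phrases, max_phrases -- PCA-reduced (6000-D) HGLMM phrase
+%       features and the corresponding phrase strings
+%   pairs -- [image name, phrase, pair identifier] triples for the split
+%   <image name> -- Fast R-CNN region features with box spatial features appended
+%   <image name>_<phrase>_<pair identifier> -- IoU of each box with the ground
+%       truth box, followed by the ground truth box coordinates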
+net = fullfile('models', 'voc_2007_trainvaltest_2012_trainval', 'vgg16_fast_rcnn_iter_100000.caffemodel');
+def = fullfile('models', 'fastrcnn_feat.prototxt');
+image_dir = fullfile('datasets', 'Flickr30kEntities', 'Images');
+output_dir = '.';
+
+% code assumes train is the first to be processed
+splits = {'train', 'test', 'val'};
+phraseCoeff = [];
+phraseMean = [];
+for i = 1:length(splits)
+    load(fullfile('data', 'flickr30k', sprintf('%sData.mat', splits{i})), 'imData');
+
+    imData.filterNonvisualPhrases();
+    if strcmp(splits{i}, 'train')
+        imData.concatenateGTBoxes();
+    end
+
+    % getPhraseWords returns a nested image x sentence x phrase
+    % cell array
+    uniquePhrases = imData.getPhraseWords();
+    uniquePhrases = vertcat(uniquePhrases{:});
+    maxPhrases = max(cellfun(@length, uniquePhrases));
+    uniquePhrases = vertcat(uniquePhrases{:});
+    uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
+    uniquePhrases = cellfun(@(f)strrep(f,'+',' '), uniquePhrases, 'UniformOutput', false);
+    uniquePhrases = cellfun(@(f)strtrim(removePunctuation(f)), uniquePhrases, 'UniformOutput', false);
+    uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
+    uniquePhrases = unique(uniquePhrases);
+    uniquePhrases = cellfun(@(f)strrep(f,' ','+'), uniquePhrases, 'UniformOutput', false);
+    phraseFeatures = single(getHGLMMFeatures(uniquePhrases))';
+    uniquePhrases = [uniquePhrases; {'unk'}];
+
+    % compute PCA parameters on first iteration (train split)
+    if isempty(phraseCoeff)
+        phraseMean = mean(phraseFeatures, 1);
+        phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
+        phraseCoeff = pca(phraseFeatures, 'NumComponents', 6000);
+    else
+        phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
+    end
+
+    phraseFeatures = phraseFeatures * phraseCoeff;
+    phraseFeatures = [phraseFeatures; zeros(1, 6000, 'single')];
+    phraseFeatures = phraseFeatures';
+
+    hdf5fn = fullfile(output_dir, sprintf('%s_imfeats.h5', splits{i}));
+    hdf5write(hdf5fn, 'max_phrases', maxPhrases, 'phrase_features', phraseFeatures, 'phrases', uniquePhrases);
+    clear phraseFeatures
+
+    pairs = cell(imData.nImages, 1);
+    for x = 1:imData.nImages
+        boxes = imData.getBoxes(x);
+        pairs{x} = cell(imData.nSentences(x), 1);
+        for y = 1:imData.nSentences(x)
+            pairs{x}{y} = cell(imData.nPhrases(x, y), 1);
+            for z = 1:imData.nPhrases(x, y)
+                box = imData.getPhraseGT(x, y, z);
+                phrase = imData.getPhrase(x,y,z).getPhraseString(imData.stopwords);
+                if ~isempty(phrase)
+                    phrase = strtrim(removePunctuation(strrep(phrase,'+',' ')));
+                end
+
+                if isempty(phrase)
+                    phrase = 'unk';
+                end
+                phrase = strrep(phrase,' ','+');
+
+                p_id = sprintf('%i_%i', y, z);
+                pairs{x}{y}{z} = {imData.imagefns{x}, phrase, p_id};
+                overlaps = getIOU(box, boxes);
+                phrase_id = sprintf('%s_%s_%s', imData.imagefns{x}, phrase, p_id);
+                overlaps = [overlaps; box'];
+                hdf5write(hdf5fn, phrase_id, overlaps, 'WriteMode', 'append');
+            end
+        end
+    end
+
+    pairs = vertcat(pairs{:});
+    pairs = vertcat(pairs{:});
+    pairs = vertcat(pairs{:});
+    hdf5write(hdf5fn, 'pairs', pairs, 'WriteMode', 'append');
+
+    % separate into batches since getFastRCNNFeatures operates on a
+    % batch level
+    batchSize = 1000;
+    nBatches = ceil(imData.nImages/batchSize);
+    for batch = 1:batchSize:imData.nImages
+        batchEnd = min(batch+batchSize-1, imData.nImages);
+        imagefns = imData.imagefns(batch:batchEnd);
+        imagedir = imData.imagedir;
+        ext = imData.ext;
+        stopwords = imData.stopwords;
+        batchData = ImageSetData(imData.split, imagefns, imagedir, ext, stopwords);
+        batchData.phrase = imData.phrase(batch:batchEnd);
+        batchData.relationship = imData.relationship(batch:batchEnd);
+        batchData.annotations = imData.annotations(batch:batchEnd);
+        batchData.boxes = imData.boxes(batch:batchEnd);
+        regionFeatures = getFastRCNNFeatures(batchData, net, def);
+        for j = 1:batchData.nImages
+            boxes = batchData.getBoxes(j);
+            imDims = batchData.imSize(j);
+            boxFeatures = boxes;
+            boxFeatures(:, 1) = boxFeatures(:,1) / imDims(2);
+            boxFeatures(:, 2) = boxFeatures(:,2) / imDims(1);
+            boxFeatures(:, 3) = boxFeatures(:,3) / imDims(2);
+            boxFeatures(:, 4) = boxFeatures(:,4) / imDims(1);
+            boxWidth = boxes(:,3) - boxes(:,1);
+            boxHeight = boxes(:,4) - boxes(:,2);
+            boxFeatures = [boxFeatures, (boxWidth.*boxHeight)./(imDims(1)*imDims(2))];
+            features = [regionFeatures{j}; boxFeatures'; boxes'];
+            hdf5write(hdf5fn, batchData.imagefns{j}, features, 'WriteMode', 'append');
+        end
+    end
+end
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c3ccf30
--- /dev/null
+++ b/main.py
@@ -0,0 +1,198 @@
+import os
+import sys
+import argparse
+
+import numpy as np
+import tensorflow as tf
+
+from model import setup_model
+from data_loader import DataLoader
+
+parser = argparse.ArgumentParser(description='Conditional Image-Text Similarity Network')
+parser.add_argument('--name', default='Conditional_Image-Text_Similarity_Network', type=str,
+                    help='name of experiment')
+parser.add_argument('--dataset', default='flickr', type=str,
+                    help='name of the dataset to use')
+parser.add_argument('--r_seed', type=int, default=42,
+                    help='random seed (default: 42)')
+parser.add_argument('--info_iterval', type=int, default=250,
+                    help='number of batches to process before outputting training status')
+parser.add_argument('--resume', default='', type=str,
+                    help='filename of model to load (default: none)')
+parser.add_argument('--test', dest='test', action='store_true', default=False,
+                    help='Run model on test set')
+parser.add_argument('--batch-size', type=int, default=200,
+                    help='input batch size for training (default: 200)')
+parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
+                    help='learning rate (default: 5e-5)')
+parser.add_argument('--embed_l1', type=float, default=5e-5,
+                    help='weight of the L1 regularization term used on the concept weight branch (default: 5e-5)')
+parser.add_argument('--max_epoch', type=int, default=0,
+                    help='maximum number of epochs, <1 indicates no limit (default: 0)')
+parser.add_argument('--no_gain_stop', type=int, default=5,
+                    help='number of epochs used to perform early stopping based on validation performance (default: 5)')
+parser.add_argument('--neg_to_pos_ratio', type=int, default=2,
+                    help='ratio of negatives to positives used during training (default: 2)')
+parser.add_argument('--minimum_gain', type=float, default=5e-4, metavar='N',
+                    help='minimum performance gain for a model to be considered better (default: 5e-4)')
+parser.add_argument('--train_success_thresh', type=float, default=0.6,
+                    help='minimum training intersection-over-union threshold for success (default: 0.6)')
+parser.add_argument('--test_success_thresh', type=float, default=0.5,
+                    help='minimum testing intersection-over-union threshold for success (default: 0.5)')
+parser.add_argument('--dim_embed', type=int, default=256,
+                    help='how many dimensions in final embedding (default: 256)')
+parser.add_argument('--max_boxes', type=int, default=200,
+                    help='maximum number of edge boxes per image (default: 200)')
+parser.add_argument('--num_embeddings', type=int, default=4,
+                    help='number of embeddings to train (default: 4)')
+parser.add_argument('--spatial', dest='spatial', action='store_true', default=False,
+                    help='Flag indicating whether to use spatial features')
+
+def main():
+    global args
+    args = parser.parse_args()
+    np.random.seed(args.r_seed)
+    tf.set_random_seed(args.r_seed)
+    phrase_feature_dim = 6000
+    region_feature_dim = 4096
+    if args.spatial:
+        if args.dataset == 'flickr':
+            region_feature_dim += 5
+        else:
+            region_feature_dim += 8
+
+    # setup placeholders
+    labels_plh = tf.placeholder(tf.float32, shape=[args.batch_size, None])
+    phrase_plh = tf.placeholder(tf.float32, shape=[args.batch_size,
+                                                   phrase_feature_dim])
+    region_plh = tf.placeholder(tf.float32, shape=[args.batch_size, None,
+                                                   region_feature_dim])
+    train_phase_plh = tf.placeholder(tf.bool, name='train_phase')
+    num_boxes_plh = tf.placeholder(tf.int32)
+
+    plh = {}
+    plh['num_boxes'] = num_boxes_plh
+    plh['labels'] = labels_plh
+    plh['phrase'] = phrase_plh
+    plh['region'] = region_plh
+    plh['train_phase'] = train_phase_plh
+
+    test_loader = DataLoader(args, region_feature_dim, phrase_feature_dim,
+                             plh, 'test')
+    model = setup_model(args, phrase_plh, region_plh, train_phase_plh,
+                        labels_plh, num_boxes_plh, region_feature_dim)
+    if args.test:
+        test(model, test_loader, model_name=args.resume)
+        sys.exit()
+
+    save_model_directory = os.path.join('runs', args.name)
+    if not os.path.exists(save_model_directory):
+        os.makedirs(save_model_directory)
+
+    train_loader = DataLoader(args, region_feature_dim, phrase_feature_dim,
+                              plh, 'train')
+    val_loader = DataLoader(args, region_feature_dim, phrase_feature_dim,
+                            plh, 'val')
+
+    # training with Adam
+    acc, best_adam = train(model, train_loader, val_loader, args.resume)
+
+    # finetune with SGD after loading the best model trained with Adam
+    best_model_filename = os.path.join('runs', args.name, 'model_best')
+    acc, best_sgd = train(model, train_loader, val_loader,
+                          best_model_filename, False, acc)
+    best_epoch = best_adam + best_sgd
+
+    # get performance on test set
+    test_acc = test(model, test_loader, model_name=best_model_filename)
+    print('best model at epoch {}: {:.2f}% (val {:.2f}%)'.format(
+        best_epoch, round(test_acc*100, 2), round(acc*100, 2)))
+
+def test(model, test_loader, sess=None, model_name = None):
+    if sess is None:
+        sess = tf.Session()
+        saver = tf.train.Saver()
+        saver.restore(sess, model_name)
+
+    region_weights = model[3]
+    correct = 0.0
+    n_iterations = test_loader.num_batches()
+    for batch_id in range(n_iterations):
+        feed_dict, gt_labels, num_pairs = test_loader.get_batch(batch_id)
+        scores = sess.run(region_weights, feed_dict = feed_dict)
+        for pair_index in range(num_pairs):
+            best_region_index = np.argmax(scores[pair_index, :])
+            correct += gt_labels[pair_index, best_region_index]
+
+    acc = correct/len(test_loader)
+    print('\n{} set localization accuracy: {:.2f}%\n'.format(
+        test_loader.split, round(acc*100, 2)))
+    return acc
+
+def process_epoch(model, train_loader, sess, train_step, epoch, suffix):
+    train_loader.shuffle()
+
+    # extract elements from model tuple
+    loss = model[0]
+    region_loss = model[1]
+    l1_loss = model[2]
+
+    n_iterations = train_loader.num_batches()
+    for batch_id in range(n_iterations):
+        feed_dict, _, _ = train_loader.get_batch(batch_id)
+        (_, total, region, concept_l1) = sess.run([train_step, loss,
+                                                   region_loss, l1_loss],
+                                                  feed_dict = feed_dict)
+
+        if batch_id % args.info_iterval == 0:
+            print('loss: {:.5f} (region: {:.5f} concept: {:.5f}) '
+                  '[{}/{}] (epoch: {}) {}'.format(total, region, concept_l1,
+                                                  (batch_id*args.batch_size),
+                                                  len(train_loader), epoch,
+                                                  suffix))
+
+def train(model, train_loader, test_loader, model_weights, use_adam = True,
+          best_acc = 0.):
+    sess = tf.Session()
+    if use_adam:
+        optim = tf.train.AdamOptimizer(args.lr)
+        suffix = ''
+    else:
+        optim = tf.train.GradientDescentOptimizer(args.lr / 10.)
+        suffix = 'ft'
+
+    weights_norm = tf.losses.get_regularization_losses()
+    weights_norm_sum = tf.add_n(weights_norm)
+    loss = model[0]
+    train_step = optim.minimize(loss + weights_norm_sum)
+    saver = tf.train.Saver()
+    init = tf.global_variables_initializer()
+    epoch = 1
+    best_epoch = 0
+    with sess.as_default():
+        init.run()
+        if model_weights:
+            saver.restore(sess, model_weights)
+            if use_adam:
+                best_acc = test(model, test_loader, sess)
+
+        # model trains until args.max_epoch is reached or it no longer
+        # improves on the validation set
+        while (epoch - best_epoch) < args.no_gain_stop and (args.max_epoch < 1 or epoch <= args.max_epoch):
+            process_epoch(model, train_loader, sess, train_step, epoch, suffix)
+            saver.save(sess, os.path.join('runs', args.name, 'checkpoint'),
+                       global_step = epoch)
+            acc = test(model, test_loader, sess)
+            if acc > best_acc:
+                saver.save(sess, os.path.join('runs', args.name, 'model_best'))
+                if (acc - args.minimum_gain) > best_acc:
+                    best_epoch = epoch
+
+                best_acc = acc
+
+            epoch += 1
+
+    return best_acc, best_epoch
+
+if __name__ == '__main__':
+    main()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..2b6abca
--- /dev/null
+++ b/model.py
@@ -0,0 +1,115 @@
+import tensorflow as tf
+
+from tensorflow.contrib.layers.python.layers import batch_norm
+from tensorflow.contrib.layers.python.layers import convolution2d
+from tensorflow.contrib.layers.python.layers import fully_connected
+from tensorflow.contrib.layers.python.layers import l2_regularizer
+
+def add_fc(x, outdim, train_phase_plh, scope_in):
+    """Returns the output of a FC-BNORM-ReLU sequence.
+
+    Arguments:
+    x -- input tensor
+    outdim -- desired output dimensions
+    train_phase_plh -- indicator whether model is in training mode
+    scope_in -- scope prefix for the desired layers
+    """
+    l2_reg = tf.contrib.layers.l2_regularizer(0.0005)
+    fc = tf.contrib.layers.fully_connected(x, outdim, activation_fn = None,
+                                           weights_regularizer = l2_reg,
+                                           scope = scope_in + '/fc')
+    fc_bnorm = batch_norm_layer(fc, train_phase_plh, scope_in + '/bnorm')
+    return tf.nn.relu(fc_bnorm, scope_in + '/relu')
+
+def concept_layer(x, outdim, train_phase_plh, concept_id, weights):
+    """Returns the weighted value of a fully connected layer.
+
+    Arguments:
+    x -- input tensor
+    outdim -- desired output dimensions
+    train_phase_plh -- indicator whether model is in training mode
+    concept_id -- identifier for the desired concept layer
+    weights -- vector of weights to be applied to the concept outputs
+    """
+    concept = add_fc(x, outdim, train_phase_plh, 'concept_%i' % concept_id)
+    concept = tf.reshape(concept, [tf.shape(concept)[0], -1])
+    weighted_concept = concept * tf.expand_dims(weights[:, concept_id-1], 1)
+    return weighted_concept
+
+def batch_norm_layer(x, train_phase, scope_bn):
+    """Returns the output of a batch norm layer."""
+    bn = tf.contrib.layers.batch_norm(x, decay=0.99, center=True, scale=True,
+                                      is_training=train_phase,
+                                      reuse=None,
+                                      trainable=True,
+                                      updates_collections=None,
+                                      scope=scope_bn)
+    return bn
+
+def embedding_branch(x, embed_dim, train_phase_plh, scope_in, do_l2norm = True, outdim = None):
+    """Applies a pair of fully connected layers to the input tensor.
+
+    Arguments:
+    x -- input tensor
+    embed_dim -- dimension of the input to the second fully connected layer
+    train_phase_plh -- indicator whether model is in training mode
+    scope_in -- scope prefix for the desired layers
+    do_l2norm -- indicates if the output should be l2 normalized
+    outdim -- dimension of the output embedding, if None outdim=embed_dim
+    """
+    embed_fc1 = add_fc(x, embed_dim, train_phase_plh, scope_in + '_embed_1')
+    if outdim is None:
+        outdim = embed_dim
+
+    l2_reg = tf.contrib.layers.l2_regularizer(0.001)
+    embed_fc2 = fully_connected(embed_fc1, outdim, activation_fn = None,
+                                weights_regularizer = l2_reg,
+                                scope = scope_in + '_embed_2')
+    if do_l2norm:
+        embed_fc2 = tf.nn.l2_normalize(embed_fc2, 1)
+
+    return embed_fc2
+
+def setup_model(args, phrase_plh, region_plh, train_phase_plh, labels_plh, num_boxes_plh, region_feature_dim):
+    """Describes the computational graph and returns the losses and outputs.
+
+    Arguments:
+    args -- command line arguments passed into the main function
+    phrase_plh -- tensor containing the phrase features
+    region_plh -- tensor containing the region features
+    train_phase_plh -- indicator whether model is in training mode
+    labels_plh -- indicates positive (1), negative (-1), or ignore (0)
+    num_boxes_plh -- number of boxes per example in the batch
+    region_feature_dim -- dimensions of the region features
+    """
+    labels_plh = tf.reshape(labels_plh, [-1, num_boxes_plh])
+    eb_fea_plh = tf.reshape(region_plh, [-1, num_boxes_plh, region_feature_dim])
+
+    final_embed = args.dim_embed
+    embed_dim = final_embed * 4
+    phrase_embed = embedding_branch(phrase_plh, embed_dim, train_phase_plh, 'phrase')
+    region_embed = embedding_branch(region_plh, embed_dim, train_phase_plh, 'region')
+    concept_weights = embedding_branch(phrase_plh, embed_dim, train_phase_plh, 'concept_weight',
+                                       do_l2norm = False, outdim = args.num_embeddings)
+    concept_loss = tf.reduce_mean(tf.norm(concept_weights, axis=1, ord=1))
+    concept_weights = tf.nn.softmax(concept_weights)
+
+    elementwise_prod = tf.expand_dims(phrase_embed, 1)*region_embed
+    joint_embed_1 = add_fc(elementwise_prod, embed_dim, train_phase_plh, 'joint_embed_1')
+    joint_embed_2 = concept_layer(joint_embed_1, final_embed, train_phase_plh, 1, concept_weights)
+    for concept_id in range(2, args.num_embeddings+1):
+        joint_embed_2 += concept_layer(joint_embed_1, final_embed, train_phase_plh,
+                                       concept_id, concept_weights)
+
+    joint_embed_2 = tf.reshape(joint_embed_2, [tf.shape(joint_embed_2)[0], num_boxes_plh, final_embed])
+    joint_embed_3 = fully_connected(joint_embed_2, 1, activation_fn = None,
+                                    weights_regularizer = l2_regularizer(0.005),
+                                    scope = 'joint_embed_3')
+    joint_embed_3 = tf.squeeze(joint_embed_3, [2])
+    region_prob = 1. / (1. + tf.exp(-joint_embed_3))
+
+    ind_labels = tf.abs(labels_plh)
+    num_samples = tf.reduce_sum(ind_labels)
+    region_loss = tf.reduce_sum(tf.log(1+tf.exp(-joint_embed_3*labels_plh))*ind_labels)/num_samples
+    total_loss = region_loss + concept_loss * args.embed_l1
+    return total_loss, region_loss, concept_loss, region_prob
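For anyone preparing their own `<split>_imfeats.h5` files with a pipeline other than `data_processing_example`, the following is a minimal sketch (not part of the original code) that checks a file against the layout `data_loader.py` reads; the example path and the bytes-vs-str handling are assumptions for illustration.

```python
import h5py
import numpy as np

def check_imfeats(path):
    """Sanity-check a <split>_imfeats.h5 file against the layout data_loader.py expects."""
    to_str = lambda s: s.decode() if isinstance(s, bytes) else s
    with h5py.File(path, 'r') as f:
        phrases = list(f['phrases'])
        # phrase_features is a #num_phrase x 6000 matrix aligned with phrases
        assert f['phrase_features'].shape == (len(phrases), 6000)

        pairs = list(f['pairs'])  # 3 x M: image name, phrase, pair identifier
        im_id, phrase, p_id = (to_str(row[0]) for row in pairs)

        feats = np.array(f[im_id])  # num_boxes x feature_dim region features
        overlaps = np.array(f['%s_%s_%s' % (im_id, phrase, p_id)])

        # per-box IoU with the ground truth followed by the 4 ground truth box coordinates
        assert len(overlaps) == len(feats) + 4
        print('%s: %d boxes, %d-d region features' % (path, feats.shape[0], feats.shape[1]))

check_imfeats('data/flickr/val_imfeats.h5')  # example path, assuming the precomputed features
```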