This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

[WIP] Visual Genome Dataset #228

Open
wants to merge 12 commits into master
2 changes: 2 additions & 0 deletions chainercv/datasets/__init__.py
@@ -7,6 +7,8 @@
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA
from chainercv.datasets.transform_dataset import TransformDataset # NOQA
from chainercv.datasets.visual_genome.visual_genome_region_descriptions_dataset import VisualGenomeRegionDescriptionsDataset # NOQA
from chainercv.datasets.visual_genome.visual_genome_utils import VisualGenomeDatasetBase # NOQA
from chainercv.datasets.voc.voc_detection_dataset import VOCDetectionDataset # NOQA
from chainercv.datasets.voc.voc_semantic_segmentation_dataset import VOCSemanticSegmentationDataset # NOQA
from chainercv.datasets.voc.voc_utils import voc_detection_label_names # NOQA
chainercv/datasets/visual_genome/__init__.py
Empty file.
254 changes: 254 additions & 0 deletions chainercv/datasets/visual_genome/visual_genome_region_descriptions_dataset.py
@@ -0,0 +1,254 @@
from collections import Counter
from collections import defaultdict
import json
import os
import pickle
import six
import string

from chainer.dataset import download
import numpy as np

from chainercv.datasets.visual_genome.visual_genome_utils import \
get_region_descriptions
from chainercv.datasets.visual_genome.visual_genome_utils import root
from chainercv.datasets.visual_genome.visual_genome_utils import \
VisualGenomeDatasetBase


def get_vocabulary(region_descriptions='auto', min_token_instances=15):
"""Creates a vocabulary based on the region descriptions of Visual Genome.

A vocabulary is a dictionary that maps each word (str) to its
corresponding id (int). Rare words are treated as the unknown word,
i.e. <unk>, and are excluded from the dictionary.

Args:
region_descriptions (str): Path to the region descriptions JSON file of
Visual Genome. If 'auto', the file is downloaded and cached
automatically.
min_token_instances (int): Words that appear fewer than this number of
times are treated as <unk>.

Returns:
dict: A dictionary mapping words to their corresponding ids.

"""
if region_descriptions == 'auto':
region_descriptions = get_region_descriptions()
return _create_word_vocabulary(region_descriptions, min_token_instances)


class VisualGenomeRegionDescriptionsDataset(VisualGenomeDatasetBase):
"""Region description class for Visual Genome dataset.

"""

def __init__(self, data_dir='auto', image_data='auto',
region_descriptions='auto', min_token_instances=15,
max_token_length=15, img_size=(720, 720)):
super(VisualGenomeRegionDescriptionsDataset, self).__init__(
data_dir=data_dir, image_data=image_data)

if region_descriptions == 'auto':
region_descriptions = get_region_descriptions()

self.img_size = img_size

self.region_ids = _get_region_ids(region_descriptions)
self.regions = _get_regions(region_descriptions)
self.phrases = _get_phrases(region_descriptions, min_token_instances,
max_token_length)

def get_example(self, i):
img_id = self.get_image_id(i)
img = self.get_image(img_id)

regions = []
phrases = []
for region_id in self.region_ids[img_id]:
# Phrases that are too long are excluded during preprocessing,
# so only keep regions that still have a phrase
if region_id in self.phrases:
regions.append(self.regions[region_id])
phrases.append(self.phrases[region_id])
regions = np.vstack(regions).astype(np.float32)
phrases = np.vstack(phrases).astype(np.int32)

return img, regions, phrases


def _get_region_ids(region_descriptions_path):
"""Image ID (int) -> Region IDs (list of int).

"""
data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root, 'region_ids.pkl')

def creator(path):
print('Caching Visual Genome region IDs...')
region_ids = defaultdict(list)
with open(region_descriptions_path) as f:
region_descriptions = json.load(f)
for region_description in region_descriptions:
for region in region_description['regions']:
img_id = region['image_id']
region_id = region['region_id']
region_ids[img_id].append(region_id)
with open(path, 'wb') as f:
    pickle.dump(dict(region_ids), f)
return dict(region_ids)

def loader(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

return download.cache_or_load_file(base_path, creator, loader)


def _get_regions(region_descriptions_path):
"""Region ID (int) -> Region bounding box (xmin, ymin).

"""
data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root, 'regions.pkl')

def creator(path):
print('Caching Visual Genome region bounding boxes...')
regions = {}
with open(region_descriptions_path) as f:
region_descriptions = json.load(f)
for region_description in region_descriptions:
for region in region_description['regions']:
region_id = region['region_id'] # int
x_min = region['x']
y_min = region['y']
x_max = x_min + region['width']
y_max = y_min + region['height']
regions[region_id] = (y_min, x_min, y_max, x_max)
with open(path, 'wb') as f:
    pickle.dump(regions, f)
return regions

def loader(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

return download.cache_or_load_file(base_path, creator, loader)


def _get_phrases(region_descriptions_path, min_token_instances,
max_token_length):
"""Region ID (int) -> Phrase (list of int).

"""
data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root,
'phrases_{}.pkl'.format(min_token_instances))

def creator(path):
print('Caching Visual Genome region descriptions...')
phrases = {}
vocab = _create_word_vocabulary(region_descriptions_path,
min_token_instances)
with open(region_descriptions_path) as f:
region_descriptions = json.load(f)
for region_description in region_descriptions:
for region in region_description['regions']:
region_id = region['region_id']
tokens = _preprocess_phrase(region['phrase']).split()
if max_token_length > 0 and \
len(tokens) < max_token_length - 1:
# <bos>, t1, t2,..., tn, <eos>, <eos>,..., <eos>
phrase = np.empty(max_token_length, dtype=np.int32)
phrase.fill(vocab['<eos>'])
phrase[0] = vocab['<bos>']
for i, token in enumerate(tokens, 1):
if token not in vocab:
token = '<unk>'
token_id = vocab[token]
phrase[i] = token_id
phrases[region_id] = phrase

with open(path, 'wb') as f:
    pickle.dump(phrases, f)
return phrases

def loader(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

return download.cache_or_load_file(base_path, creator, loader)


def _create_word_vocabulary(region_descriptions_path, min_token_instances):
"""Word (str) -> Word ID (int).

"""
data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root,
'vocab_{}.txt'.format(min_token_instances))

def creator(path):
print('Creating vocabulary from region descriptions (ignoring words '
      'that appear fewer than {} times)...'.format(min_token_instances))
words = _load_words(region_descriptions_path,
min_token_instances=min_token_instances)
vocab = {}
index = 0
with open(path, 'w') as f:
for word in words:
if word not in vocab:
vocab[word] = index
index += 1
f.write(word + '\n')
return vocab

def loader(path):
vocab = {}
with open(path) as f:
for i, word in enumerate(f):
vocab[word.strip()] = i
return vocab

return download.cache_or_load_file(base_path, creator, loader)


def _load_words(region_descriptions_path, min_token_instances):
# Count how often each word occurs over all region descriptions so that
# only words that appear at least min_token_instances times are kept
word_counts = Counter()
with open(region_descriptions_path) as f:
region_descriptions = json.load(f)
for region_description in region_descriptions:
for region in region_description['regions']:
for word in _preprocess_phrase(region['phrase']).split():
word_counts[word] += 1

words = []
for word, count in six.iteritems(word_counts):
if min_token_instances is None or count >= min_token_instances:
words.append(word)

# Sort to make sure that word orders are consistent
words = sorted(words)

words.insert(0, '<unk>')
words.insert(0, '<eos>')
words.insert(0, '<bos>')

return words


def _preprocess_phrase(phrase):
# Preprocess phrases according to the DenseCap implementation
# https://github.com/jcjohnson/densecap/blob/master/preprocess.py
replacements = {
u'\xa2': u'cent',
u'\xb0': u' degree',
u'\xbd': u'half',
u'\xe7': u'c',
u'\xe8': u'e',
u'\xe9': u'e',
u'\xfb': u'u',
u'\u2014': u'-',
u'\u2026': u'',
u'\u2122': u'',
}

for k, v in six.iteritems(replacements):
phrase = phrase.replace(k, v)

trans = str.maketrans('', '', string.punctuation)
return str(phrase).lower().translate(trans)
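As a quick orientation for reviewers, here is a minimal usage sketch of the API added by this file. It is not part of the diff: it assumes the automatic download paths (`'auto'`) are reachable and that the exports added to `chainercv/datasets/__init__.py` above are in place; the shapes in the comments follow `get_example`.

```python
# Usage sketch (illustrative only, not part of the proposed code).
from chainercv.datasets import VisualGenomeRegionDescriptionsDataset
from chainercv.datasets.visual_genome.visual_genome_region_descriptions_dataset \
    import get_vocabulary

vocab = get_vocabulary()  # word (str) -> word ID (int), cached on disk
dataset = VisualGenomeRegionDescriptionsDataset()

img, regions, phrases = dataset.get_example(0)
# img:     CHW float32 image read with chainercv.utils.read_image
# regions: (R, 4) float32 array of (y_min, x_min, y_max, x_max)
# phrases: (R, max_token_length) int32 array of word IDs, each row
#          starting with <bos> and padded with <eos>
```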
110 changes: 110 additions & 0 deletions chainercv/datasets/visual_genome/visual_genome_utils.py
@@ -0,0 +1,110 @@
import json
import os

import chainer
from chainer.dataset import download

from chainercv import utils

root = 'pfnet/chainer/visual_genome'

vg_100k_url = 'https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip'
vg_100k_2_url = 'https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip'
image_data_url = 'http://visualgenome.org/static/data/dataset/' \
'image_data.json.zip'
region_descriptions_url = 'http://visualgenome.org/static/data/dataset/' \
'region_descriptions.json.zip'


def get_visual_genome():
"""Get the default path to the Visual Genome image directory.

Returns:
str: A path to the image directory.

"""
def move_files(src_dir, dst_dir):
# Move all files in the src_dir to the dst_dir and remove the src_dir
for f in os.listdir(src_dir):
src = os.path.join(src_dir, f)
if os.path.isfile(src):
dst = os.path.join(dst_dir, f)
os.rename(src, dst)
os.rmdir(src_dir)

data_root = download.get_dataset_directory(root)
base_path = os.path.join(data_root, 'VG_100K_ALL')

if os.path.exists(base_path):
return base_path

print('Caching Visual Genome image files...')
os.mkdir(base_path)
move_files(_get_extract_data(vg_100k_url, data_root, 'VG_100K'), base_path)
move_files(_get_extract_data(vg_100k_2_url, data_root, 'VG_100K_2'),
base_path)

return base_path


def get_image_data():
"""Get the default path to the image data JSON file.

Returns:
str: A path to the image data JSON file.

"""
data_root = download.get_dataset_directory(root)
return _get_extract_data(image_data_url, data_root, 'image_data.json')


def get_region_descriptions():
"""Get the default path to the region descriptions JSON file.

Returns:
str: A path to the region descriptions JSON file.

"""
data_root = download.get_dataset_directory(root)
return _get_extract_data(region_descriptions_url, data_root,
'region_descriptions.json')


class VisualGenomeDatasetBase(chainer.dataset.DatasetMixin):
"""Base class for Visual Genome dataset.

"""

def __init__(self, data_dir='auto', image_data='auto'):
if data_dir == 'auto':
data_dir = get_visual_genome()
if image_data == 'auto':
image_data = get_image_data()
self.data_dir = data_dir

with open(image_data, 'r') as f:
img_ids = [img_data['image_id'] for img_data in json.load(f)]
self.img_ids = sorted(img_ids)

def __len__(self):
return len(self.img_ids)

def get_image_id(self, i):
return self.img_ids[i]

def get_image(self, img_id):
img_path = os.path.join(self.data_dir, str(img_id) + '.jpg')
return utils.read_image(img_path, color=True)


def _get_extract_data(url, data_root, member_path):
base_path = os.path.join(data_root, member_path)

if os.path.exists(base_path):
return base_path

download_file_path = utils.cached_download(url)
ext = os.path.splitext(url)[1]
utils.extractall(download_file_path, data_root, ext)

return base_path
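To show how the base class in this file is intended to be reused, a hypothetical sketch of another Visual Genome dataset built on `VisualGenomeDatasetBase` follows. The class name and the `annotations` lookup are placeholders, not part of this PR.

```python
# Hypothetical sketch: reusing VisualGenomeDatasetBase for another
# annotation type. The class name and `annotations` dict are placeholders.
from chainercv.datasets.visual_genome.visual_genome_utils import \
    VisualGenomeDatasetBase


class MyVisualGenomeDataset(VisualGenomeDatasetBase):

    def __init__(self, data_dir='auto', image_data='auto'):
        super(MyVisualGenomeDataset, self).__init__(
            data_dir=data_dir, image_data=image_data)
        # Map image ID (int) -> some per-image annotation (left empty here).
        self.annotations = {}

    def get_example(self, i):
        # The base class resolves the i-th image ID and reads the image.
        img_id = self.get_image_id(i)
        img = self.get_image(img_id)
        return img, self.annotations.get(img_id)
```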