Skip to content

Commit

Permalink
Add ImageNet Grain dataset
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 471253754
  • Loading branch information
michbad authored and copybara-github committed Nov 3, 2022
1 parent 5011286 commit 2f325bf
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 0 deletions.
5 changes: 5 additions & 0 deletions init2winit/dataset_lib/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from init2winit.dataset_lib import librispeech
from init2winit.dataset_lib import lm1b_v2
from init2winit.dataset_lib import mlperf_imagenet_dataset
from init2winit.dataset_lib import mlperf_imagenet_dataset_grain
from init2winit.dataset_lib import nqm_noise
from init2winit.dataset_lib import ogbg_molpcba
from init2winit.dataset_lib import proteins
Expand Down Expand Up @@ -84,6 +85,10 @@
mlperf_imagenet_dataset.DEFAULT_HPARAMS,
mlperf_imagenet_dataset.METADATA,
mlperf_imagenet_dataset.get_fake_batch),
'mlperf_imagenet_grain':
_Dataset(mlperf_imagenet_dataset_grain.get_mlperf_imagenet,
mlperf_imagenet_dataset_grain.DEFAULT_HPARAMS,
mlperf_imagenet_dataset_grain.METADATA, None),
'svhn_no_extra':
_Dataset(small_image_datasets.get_svhn_no_extra,
small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
Expand Down
134 changes: 134 additions & 0 deletions init2winit/dataset_lib/mlperf_imagenet_dataset_grain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# coding=utf-8
# Copyright 2022 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ImageNet input pipeline with MLPerf preprocessing."""

import itertools

from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import mlperf_input_pipeline_grain
import jax
from ml_collections.config_dict import config_dict
import numpy as np
import tensorflow.compat.v2 as tf

# Number of ImageNet classification labels.
NUM_CLASSES = 1000

# Default hyperparameters. train_size/valid_size are the standard
# ImageNet-2012 train and validation split cardinalities.
DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        input_shape=(224, 224, 3),
        output_shape=(NUM_CLASSES,),
        train_size=1281167,
        valid_size=50000))

# Targets are already one-hot encoded by the input pipeline
# (NormalizeAndOneHot), so the loss must not apply one-hot again.
METADATA = {
    'apply_one_hot_in_loss': False,
}


def get_mlperf_imagenet(rng, batch_size, eval_batch_size, hps=None):
  """Data generators for imagenet.

  Args:
    rng: RNG seed that is split into a shuffle seed and a seed that is folded
      into a per-example seed.
    batch_size: the *global* batch size used for training.
    eval_batch_size: the *global* batch size used for evaluation.
    hps: the hparams for the experiment, only required field is valid_size.

  Returns:
    A data_utils.Dataset for the MLPerf version of ImageNet.

  Raises:
    ValueError: if hps is missing, or if either global batch size is not
      divisible by the number of devices.
  """
  # Fail fast with a clear message instead of an AttributeError below.
  if hps is None:
    raise ValueError('hps is required; it must at least define valid_size.')
  if batch_size % jax.device_count() != 0:
    raise ValueError('Require batch_size to be divisible by '
                     'jax.device_count(), received batch_size={}, '
                     'device_count={}.'.format(batch_size,
                                               jax.device_count()))
  if eval_batch_size % jax.device_count() != 0:
    raise ValueError('Require eval_batch_size to be divisible by '
                     'jax.device_count(), received eval_batch_size={}, '
                     'device_count={}.'.format(eval_batch_size,
                                               jax.device_count()))
  # Batch sizes above are global; each host loads only its share.
  host_batch_size = batch_size // jax.process_count()
  eval_host_batch_size = eval_batch_size // jax.process_count()

  # +1 so the final partial validation batch is not dropped; it is padded
  # to full shape via data_utils.maybe_pad_batch below.
  max_eval_steps = hps.valid_size // eval_batch_size + 1

  input_dtype = tf.bfloat16

  # Training stream: infinite (is_train=True), shuffled, with random
  # crop/flip augmentation applied by the pipeline.
  train_ds = mlperf_input_pipeline_grain.load_split(
      host_batch_size,
      dtype=input_dtype,
      split='train',
      rng=rng,
      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot(),
      is_train=True)

  # Finite pass over the train split, used to evaluate train metrics.
  eval_train_ds = mlperf_input_pipeline_grain.load_split(
      eval_host_batch_size,
      dtype=input_dtype,
      split='train',
      rng=rng,
      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot())

  eval_ds = mlperf_input_pipeline_grain.load_split(
      eval_host_batch_size,
      dtype=input_dtype,
      split='validation',
      rng=rng,
      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot())

  # TODO(mbadura): Add test dataset like in mlperf_imagenet_dataset

  # We cannot use tfds.as_numpy because this calls tensor.numpy() which does an
  # additional copy of the tensor, instead we call tensor._numpy() below.
  def train_iterator_fn():
    return data_utils.iterator_as_numpy(iter(train_ds))

  def eval_train_epoch(num_batches=None):
    # NOTE(review): num_batches=None maps to 0, so islice yields nothing and
    # train-set evaluation is skipped unless a count is explicitly passed —
    # confirm this matches the non-grain mlperf_imagenet_dataset behavior.
    if num_batches is None:
      num_batches = 0
    eval_train_iter = iter(eval_train_ds)

    for batch in data_utils.iterator_as_numpy(
        itertools.islice(eval_train_iter, num_batches)):
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  def valid_epoch(num_batches=None):
    # Default to enough steps to cover the whole validation split.
    if num_batches is None:
      num_batches = max_eval_steps
    valid_iter = iter(eval_ds)
    for batch in data_utils.iterator_as_numpy(
        itertools.islice(valid_iter, num_batches)):
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  def test_epoch(num_batches=None):
    """Null iterator: no test split is wired up yet."""
    # pylint: disable=unreachable
    del num_batches
    return
    yield  # This yield is needed to make this a valid (null) iterator.
    # pylint: enable=unreachable

  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                            test_epoch)


def get_fake_batch(hps):
  """Builds a batch of all-ones arrays shaped like real ImageNet batches.

  Args:
    hps: hparams with batch_size, input_shape, output_shape and model_dtype.

  Returns:
    A dict with 'inputs', 'targets' and 'weights' arrays of ones, all using
    hps.model_dtype.
  """
  shapes = {
      'inputs': (hps.batch_size, *hps.input_shape),
      'targets': (hps.batch_size, *hps.output_shape),
      'weights': (hps.batch_size,),
  }
  return {
      name: np.ones(shape, dtype=hps.model_dtype)
      for name, shape in shapes.items()
  }
222 changes: 222 additions & 0 deletions init2winit/dataset_lib/mlperf_input_pipeline_grain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# coding=utf-8
# Copyright 2022 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Efficient ImageNet input pipeline using tf.data.Dataset."""

import dataclasses
from typing import Any, Sequence

from clu import preprocess_spec
from grain._src.tensorflow import transforms
import grain.tensorflow as grain
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

# Number of images in the ImageNet validation split.
# NOTE(review): not referenced elsewhere in this file — confirm it is used.
EVAL_IMAGES = 50000
# Padding term used when computing the central-crop size at eval time.
CROP_PADDING = 32
# Per-channel ImageNet RGB mean/stddev, scaled to the [0, 255] pixel range.
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# Number of ImageNet classification labels.
NUM_CLASSES = 1000

# Alias for the flat feature dict passed between grain transforms.
FlatFeatures = preprocess_spec.FlatFeatures


def transpose_and_normalize_image(image):
  """Normalizes an image by the per-channel ImageNet mean and stddev.

  NOTE(review): despite the name, no transpose happens here — only
  normalization; confirm whether a transpose was intended.

  Args:
    image: an image tensor with 3 trailing RGB channels, presumably in the
      [0, 255] range (MEAN_RGB/STDDEV_RGB are scaled by 255 — verify).

  Returns:
    The normalized image tensor.
  """
  channel_mean = tf.constant([[MEAN_RGB]], dtype=image.dtype)
  channel_stddev = tf.constant([[STDDEV_RGB]], dtype=image.dtype)
  return (image - channel_mean) / channel_stddev


# Note that this will run before batching.
@dataclasses.dataclass(frozen=True)
class NormalizeAndOneHot(preprocess_spec.MapTransform):
  """Maps 'image'/'label' to normalized 'inputs' and one-hot 'targets'."""

  def _transform(self, features: FlatFeatures) -> FlatFeatures:
    # pop() both consumes and removes the raw keys in one step.
    features['inputs'] = transpose_and_normalize_image(features.pop('image'))
    features['targets'] = tf.one_hot(features.pop('label'), NUM_CLASSES)
    return features


@dataclasses.dataclass(frozen=True)
class DecodeRandomCropAndResize(preprocess_spec.RandomMapTransform):
  """Decodes the images and extracts a random crop.

  Expects features['image'] to be an *undecoded* JPEG string (the TFDS
  loader uses SkipDecoding), so only the sampled crop window is decoded
  via decode_and_crop_jpeg instead of decoding the full image.
  """

  # Side length (pixels) of the square output image.
  resize_size: int

  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:
    image = features['image']
    # Sample a deterministic (seeded) crop: aspect ratio in [0.75, 1.33],
    # area in [5%, 100%] of the image, covering at least 10% of the dummy
    # whole-image bbox; falls back to the full image after 10 attempts.
    bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
        tf.image.extract_jpeg_shape(image),
        bbox,
        seed=seed,
        min_object_covered=0.1,
        aspect_ratio_range=(0.75, 1.33),
        area_range=(0.05, 1.0),
        max_attempts=10,
        use_image_if_no_bounding_boxes=True)

    # Convert the sampled bbox into the [y, x, h, w] crop window expected
    # by decode_and_crop_jpeg.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])

    # Decode only the crop window, then resize to the target square size.
    # resize expects a batch, hence the wrap/unwrap with [image] ... [0].
    image = tf.image.decode_and_crop_jpeg(image, crop_window, channels=3)
    image = tf.image.resize([image], [self.resize_size, self.resize_size],
                            method='bicubic')[0]
    features['image'] = image
    return features


@dataclasses.dataclass(frozen=True)
class CentralCropAndResize(preprocess_spec.MapTransform):
  """Makes a central crop of a given size.

  Expects features['image'] to be an *undecoded* JPEG string (the TFDS
  loader uses SkipDecoding); only the central crop window is decoded.
  """

  # Side length (pixels) of the square output image.
  resize_size: int

  def _transform(self, features: FlatFeatures) -> FlatFeatures:
    image = features['image']
    shape = tf.image.extract_jpeg_shape(image)
    # Crop a square whose side is resize_size/(resize_size + CROP_PADDING)
    # of the shorter image side — the standard ResNet eval center crop.
    crop_size = tf.cast(
        ((self.resize_size / (self.resize_size + CROP_PADDING)) *
         tf.cast(tf.minimum(shape[0], shape[1]), tf.float32)), tf.int32)
    # Center the crop window; +1 rounds the offset like ceil-division.
    offset_y, offset_x = [((shape[i] - crop_size) + 1) // 2 for i in range(2)]
    crop_window = tf.stack([offset_y, offset_x, crop_size, crop_size])
    image = tf.image.decode_and_crop_jpeg(image, crop_window, channels=3)
    # resize expects a batch, hence the wrap/unwrap with [image] ... [0].
    image = tf.image.resize([image], [self.resize_size, self.resize_size],
                            method='bicubic')[0]
    features['image'] = image
    return features


@dataclasses.dataclass(frozen=True)
class RandomFlipLeftRight(preprocess_spec.RandomMapTransform):
  """Randomly mirrors the image horizontally, driven by a stateless seed."""

  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:
    image = features['image']
    flipped = tf.image.stateless_random_flip_left_right(image, seed)
    features['image'] = flipped
    return features


@dataclasses.dataclass(frozen=True)
class ReshapeAndConvertDtype(preprocess_spec.MapTransform):
  """Fixes the image to a static square shape and casts it to a dtype."""

  # Side length (pixels) the image is asserted/reshaped to.
  resize_size: int
  # Target dtype for convert_image_dtype.
  dtype: Any

  def _transform(self, features: FlatFeatures) -> FlatFeatures:
    side = self.resize_size
    # Give the tensor a static [side, side, 3] shape, then convert dtype.
    reshaped = tf.reshape(features['image'], [side, side, 3])
    features['image'] = tf.image.convert_image_dtype(reshaped, self.dtype)
    return features


@dataclasses.dataclass(frozen=True)
class DropFeatures(preprocess_spec.MapTransform):
  """Removes the named features from the feature dict."""

  # Keys to drop from the feature dict.
  feature_names: Sequence[str]

  def _transform(self, features: FlatFeatures) -> FlatFeatures:
    dropped = set(self.feature_names)
    return {
        name: value
        for name, value in features.items()
        if name not in dropped
    }


def load_split(batch_size,
               split,
               dtype,
               rng=None,
               data_dir=None,
               image_size=224,
               preprocess_transform=None,
               is_train=False):
  """Loads an ImageNet split with MLPerf preprocessing via Grain.

  Args:
    batch_size: the per-host batch size to use for `dataset.batch()`.
    split: a string representing the dataset split to use, one of 'train',
      'validation', or 'test'.
    dtype: the dtype of the image in the data pipeline.
    rng: RNG seed that is split into a shuffle seed and a seed that is folded
      into a per-example seed.
    data_dir: an optional directory to read the ImageNet data from; defaults
      to the TFDS ArrayRecord data dir.
    image_size: the size to resize the images to using `tf.image.resize(...,
      method='bicubic')`.
    preprocess_transform: an optional grain transform appended after image
      preprocessing (e.g. NormalizeAndOneHot); it receives the flat feature
      dict and is expected to produce 'inputs' and 'targets'.
    is_train: True if loading the train dataset which will be infinite,
      so it can't be cached.

  Returns:
    A batched, preprocessed (and, for training, shuffled and repeated)
    dataset for ImageNet based off the MLPerf codebase. Note that for
    evaluation the final partial batches are not padded to the same shape,
    so callers should call `data_utils.maybe_pad_batch(batch,
    eval_host_batch_size)` to pad them.

  Raises:
    ValueError: if split is not one of 'train', 'validation', 'test'.
  """
  if split not in ['train', 'validation', 'test']:
    raise ValueError('Invalid split name {}.'.format(split))

  # Always start from the beginning of the stream; resuming the input
  # pipeline from a checkpointed step is not supported here.
  initial_step = 1

  grain.config.update('tf_interleaved_shuffle', True)

  # The init2winit convention is to specify batch_size per CPU host.
  global_batch_size = jax.process_count() * batch_size
  if split == 'train':
    start_index = (initial_step - 1) * global_batch_size + jax.process_index()
  else:
    start_index = jax.process_index()

  # Tell TFDS to not decode the image: decoding is fused with the crop
  # (tf.image.decode_and_crop_jpeg) inside the crop transforms. Both
  # branches need this, so it is hoisted out of the if/else.
  decoders = {'image': tfds.decode.SkipDecoding()}
  if split == 'train':
    # The infinite training stream cannot be cached; the finite pass over
    # the train split used for eval can.
    transformations = [] if is_train else [transforms.CacheTransform()]
    transformations += [
        DecodeRandomCropAndResize(resize_size=image_size),
        RandomFlipLeftRight(),
        ReshapeAndConvertDtype(resize_size=image_size, dtype=dtype),
        DropFeatures(('file_name',))
    ]
  else:
    transformations = [
        # Use image_size (previously hard-coded to 224) so the central crop
        # agrees with ReshapeAndConvertDtype below for any image_size.
        CentralCropAndResize(resize_size=image_size),
        ReshapeAndConvertDtype(resize_size=image_size, dtype=dtype),
        DropFeatures(('file_name',)),
        transforms.CacheTransform()
    ]

  if preprocess_transform is not None:
    transformations.append(preprocess_transform)

  # Respect an explicitly passed data_dir. Previously the parameter was
  # unconditionally overwritten here, making it dead.
  if data_dir is None:
    data_dir = tfds.core.constants.ARRAY_RECORD_DATA_DIR

  loader = grain.load_from_tfds(
      name='imagenet2012',
      data_dir=data_dir,
      split=split,
      shuffle=True,
      seed=rng,
      shard_options=grain.ShardByJaxProcess(drop_remainder=split == 'train'),
      decoders=decoders,
      transformations=transformations,
      num_epochs=None if is_train else 1,
      batch_size=batch_size)

  return loader.as_dataset(start_index=start_index)

0 comments on commit 2f325bf

Please sign in to comment.