diff --git a/init2winit/dataset_lib/datasets.py b/init2winit/dataset_lib/datasets.py
index 319b211c..590e2ff1 100644
--- a/init2winit/dataset_lib/datasets.py
+++ b/init2winit/dataset_lib/datasets.py
@@ -24,6 +24,7 @@
 from init2winit.dataset_lib import librispeech
 from init2winit.dataset_lib import lm1b_v2
 from init2winit.dataset_lib import mlperf_imagenet_dataset
+from init2winit.dataset_lib import mlperf_imagenet_dataset_grain
 from init2winit.dataset_lib import nqm_noise
 from init2winit.dataset_lib import ogbg_molpcba
 from init2winit.dataset_lib import proteins
@@ -84,6 +85,10 @@
                  mlperf_imagenet_dataset.DEFAULT_HPARAMS,
                  mlperf_imagenet_dataset.METADATA,
                  mlperf_imagenet_dataset.get_fake_batch),
+    'mlperf_imagenet_grain':
+        _Dataset(mlperf_imagenet_dataset_grain.get_mlperf_imagenet,
+                 mlperf_imagenet_dataset_grain.DEFAULT_HPARAMS,
+                 mlperf_imagenet_dataset_grain.METADATA, None),
     'svhn_no_extra':
         _Dataset(small_image_datasets.get_svhn_no_extra,
                  small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
diff --git a/init2winit/dataset_lib/mlperf_imagenet_dataset_grain.py b/init2winit/dataset_lib/mlperf_imagenet_dataset_grain.py
new file mode 100644
index 00000000..162bef41
--- /dev/null
+++ b/init2winit/dataset_lib/mlperf_imagenet_dataset_grain.py
@@ -0,0 +1,134 @@
+# coding=utf-8
+# Copyright 2022 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ImageNet input pipeline with MLPerf preprocessing."""
+
+import itertools
+
+from init2winit.dataset_lib import data_utils
+from init2winit.dataset_lib import mlperf_input_pipeline_grain
+import jax
+from ml_collections.config_dict import config_dict
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+NUM_CLASSES = 1000
+
+DEFAULT_HPARAMS = config_dict.ConfigDict(
+    dict(
+        input_shape=(224, 224, 3),
+        output_shape=(NUM_CLASSES,),
+        train_size=1281167,
+        valid_size=50000))
+
+METADATA = {
+    'apply_one_hot_in_loss': False,
+}
+
+
+def get_mlperf_imagenet(rng, batch_size, eval_batch_size, hps=None):
+  """Data generators for ImageNet.
+
+  Args:
+    rng: RNG seed that is split into a shuffle seed and a seed that is folded
+      into a per-example seed.
+    batch_size: the *global* batch size used for training.
+    eval_batch_size: the *global* batch size used for evaluation.
+    hps: the hparams for the experiment; the only required field is
+      valid_size.
+
+  Returns:
+    A data_utils.Dataset for the MLPerf version of ImageNet.
+  """
+  if batch_size % jax.device_count() != 0:
+    raise ValueError('Require batch_size % jax.device_count() == 0, received '
+                     'batch_size={}, device_count={}.'.format(
+                         batch_size, jax.device_count()))
+  if eval_batch_size % jax.device_count() != 0:
+    raise ValueError(
+        'Require eval_batch_size % jax.device_count() == 0, received '
+        'eval_batch_size={}, device_count={}.'.format(
+            eval_batch_size, jax.device_count()))
+  host_batch_size = batch_size // jax.process_count()
+  eval_host_batch_size = eval_batch_size // jax.process_count()
+
+  max_eval_steps = hps.valid_size // eval_batch_size + 1
+
+  input_dtype = tf.bfloat16
+
+  train_ds = mlperf_input_pipeline_grain.load_split(
+      host_batch_size,
+      dtype=input_dtype,
+      split='train',
+      rng=rng,
+      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot(),
+      is_train=True)
+
+  eval_train_ds = mlperf_input_pipeline_grain.load_split(
+      eval_host_batch_size,
+      dtype=input_dtype,
+      split='train',
+      rng=rng,
+      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot())
+
+  eval_ds = mlperf_input_pipeline_grain.load_split(
+      eval_host_batch_size,
+      dtype=input_dtype,
+      split='validation',
+      rng=rng,
+      preprocess_transform=mlperf_input_pipeline_grain.NormalizeAndOneHot())
+
+  # TODO(mbadura): Add test dataset like in mlperf_imagenet_dataset
+
+  # We cannot use tfds.as_numpy because it calls tensor.numpy(), which makes
+  # an additional copy of the tensor; data_utils.iterator_as_numpy instead
+  # relies on tensor._numpy() to avoid the copy.
+  def train_iterator_fn():
+    return data_utils.iterator_as_numpy(iter(train_ds))
+
+  def eval_train_epoch(num_batches=None):
+    if num_batches is None:
+      num_batches = 0  # By default, produce no eval batches on the train set.
+    eval_train_iter = iter(eval_train_ds)
+
+    for batch in data_utils.iterator_as_numpy(
+        itertools.islice(eval_train_iter, num_batches)):
+      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)
+
+  def valid_epoch(num_batches=None):
+    if num_batches is None:
+      num_batches = max_eval_steps
+    valid_iter = iter(eval_ds)
+    for batch in data_utils.iterator_as_numpy(
+        itertools.islice(valid_iter, num_batches)):
+      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)
+
+  def test_epoch(num_batches=None):
+    # pylint: disable=unreachable
+    del num_batches
+    return
+    yield  # This yield is needed to make this a valid (null) iterator.
+    # pylint: enable=unreachable
+
+  return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
+                            test_epoch)
+
+
+def get_fake_batch(hps):
+  """Returns a fake batch with the shapes this dataset produces."""
+  return {
+      'inputs':
+          np.ones((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype),
+      'targets':
+          np.ones((hps.batch_size, *hps.output_shape), dtype=hps.model_dtype),
+      'weights':
+          np.ones((hps.batch_size,), dtype=hps.model_dtype),
+  }
diff --git a/init2winit/dataset_lib/mlperf_input_pipeline_grain.py b/init2winit/dataset_lib/mlperf_input_pipeline_grain.py
new file mode 100644
index 00000000..7455cf82
--- /dev/null
+++ b/init2winit/dataset_lib/mlperf_input_pipeline_grain.py
@@ -0,0 +1,222 @@
+# coding=utf-8
+# Copyright 2022 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Efficient ImageNet input pipeline using Grain."""
+
+import dataclasses
+from typing import Any, Sequence
+
+from clu import preprocess_spec
+from grain._src.tensorflow import transforms
+import grain.tensorflow as grain
+import jax
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+EVAL_IMAGES = 50000
+CROP_PADDING = 32
+MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+NUM_CLASSES = 1000
+
+FlatFeatures = preprocess_spec.FlatFeatures
+
+
+def normalize_image(image):
+  """Normalizes an image with the ImageNet mean and standard deviation."""
+  mean = tf.constant([[MEAN_RGB]], dtype=image.dtype)
+  stddev = tf.constant([[STDDEV_RGB]], dtype=image.dtype)
+  image -= mean
+  image /= stddev
+  return image
+
+
+# Note that this will run before batching.
+@dataclasses.dataclass(frozen=True)
+class NormalizeAndOneHot(preprocess_spec.MapTransform):
+  """Normalizes images and one-hot encodes labels."""
+
+  def _transform(self, features: FlatFeatures) -> FlatFeatures:
+    features['inputs'] = normalize_image(features['image'])
+    features['targets'] = tf.one_hot(features['label'], NUM_CLASSES)
+    del features['label']
+    del features['image']
+    return features
+
+
+@dataclasses.dataclass(frozen=True)
+class DecodeRandomCropAndResize(preprocess_spec.RandomMapTransform):
+  """Decodes the images and extracts a random crop."""
+
+  resize_size: int
+
+  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:
+    image = features['image']
+    bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+    bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
+        tf.image.extract_jpeg_shape(image),
+        bbox,
+        seed=seed,
+        min_object_covered=0.1,
+        aspect_ratio_range=(0.75, 1.33),
+        area_range=(0.05, 1.0),
+        max_attempts=10,
+        use_image_if_no_bounding_boxes=True)
+
+    offset_y, offset_x, _ = tf.unstack(bbox_begin)
+    target_height, target_width, _ = tf.unstack(bbox_size)
+    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+
+    image = tf.image.decode_and_crop_jpeg(image, crop_window, channels=3)
+    image = tf.image.resize([image], [self.resize_size, self.resize_size],
+                            method='bicubic')[0]
+    features['image'] = image
+    return features
+
+
+@dataclasses.dataclass(frozen=True)
+class CentralCropAndResize(preprocess_spec.MapTransform):
+  """Makes a central crop and resizes it to a given size."""
+
+  resize_size: int
+
+  def _transform(self, features: FlatFeatures) -> FlatFeatures:
+    image = features['image']
+    shape = tf.image.extract_jpeg_shape(image)
+    crop_size = tf.cast(
+        ((self.resize_size / (self.resize_size + CROP_PADDING)) *
+         tf.cast(tf.minimum(shape[0], shape[1]), tf.float32)), tf.int32)
+    offset_y, offset_x = [((shape[i] - crop_size) + 1) // 2 for i in range(2)]
+    crop_window = tf.stack([offset_y, offset_x, crop_size, crop_size])
+    image = tf.image.decode_and_crop_jpeg(image, crop_window, channels=3)
+    image = tf.image.resize([image], [self.resize_size, self.resize_size],
+                            method='bicubic')[0]
+    features['image'] = image
+    return features
+
+
+@dataclasses.dataclass(frozen=True)
+class RandomFlipLeftRight(preprocess_spec.RandomMapTransform):
+  """Randomly flips an image horizontally."""
+
+  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:
+    features['image'] = tf.image.stateless_random_flip_left_right(
+        features['image'], seed)
+    return features
+
+
+@dataclasses.dataclass(frozen=True)
+class ReshapeAndConvertDtype(preprocess_spec.MapTransform):
+  """Reshapes and converts to a given dtype."""
+
+  resize_size: int
+  dtype: Any
+
+  def _transform(self, features: FlatFeatures) -> FlatFeatures:
+    image = features['image']
+    image = tf.reshape(image, [self.resize_size, self.resize_size, 3])
+    image = tf.image.convert_image_dtype(image, self.dtype)
+    features['image'] = image
+    return features
+
+
+@dataclasses.dataclass(frozen=True)
+class DropFeatures(preprocess_spec.MapTransform):
+  """Drops the given features from the feature dict."""
+
+  feature_names: Sequence[str]
+
+  def _transform(self, features: FlatFeatures) -> FlatFeatures:
+    return {k: v for k, v in features.items() if k not in self.feature_names}
+
+
+def load_split(batch_size,
+               split,
+               dtype,
+               rng=None,
+               data_dir=None,
+               image_size=224,
+               preprocess_transform=None,
+               is_train=False):
+  """Loads a batched and preprocessed dataset for the given ImageNet split.
+
+  Args:
+    batch_size: the per-host batch size passed to the Grain loader.
+    split: a string representing the dataset split to use, one of 'train',
+      'validation', or 'test'.
+    dtype: the dtype of the image in the data pipeline.
+    rng: RNG seed that is split into a shuffle seed and a seed that is folded
+      into a per-example seed.
+    data_dir: an optional directory to read the ImageNet ArrayRecord files
+      from; defaults to the TFDS ArrayRecord data dir.
+    image_size: the size to resize the images to using `tf.image.resize(...,
+      method='bicubic')`.
+    preprocess_transform: an optional transform applied after the per-split
+      transformations, e.g. NormalizeAndOneHot, which maps the 'image' and
+      'label' features to 'inputs' and 'targets'.
+    is_train: True if loading the train split, which is repeated indefinitely
+      and therefore cannot be cached.
+
+  Returns:
+    A dataset that is batched and preprocessed, and optionally shuffled and
+    repeated, for ImageNet based off the MLPerf codebase. Note that for
+    evaluation, the final partial batches are not yet padded to be the same
+    shape, so callers should also call `data_utils.maybe_pad_batch(batch,
+    eval_host_batch_size)` to pad them.
+  """
+  if split not in ['train', 'validation', 'test']:
+    raise ValueError('Invalid split name {}.'.format(split))
+
+  initial_step = 1
+
+  grain.config.update('tf_interleaved_shuffle', True)
+
+  # The init2winit convention is to specify batch_size per CPU host.
+  global_batch_size = jax.process_count() * batch_size
+  if split == 'train':
+    start_index = (initial_step - 1) * global_batch_size + jax.process_index()
+  else:
+    start_index = jax.process_index()
+
+  if split == 'train':
+    # Tell TFDS not to decode the image, since decoding is fused with the
+    # random crop.
+    decoders = {'image': tfds.decode.SkipDecoding()}
+    transformations = [] if is_train else [transforms.CacheTransform()]
+    transformations += [
+        DecodeRandomCropAndResize(resize_size=image_size),
+        RandomFlipLeftRight(),
+        ReshapeAndConvertDtype(resize_size=image_size, dtype=dtype),
+        DropFeatures(('file_name',))
+    ]
+  else:
+    decoders = {'image': tfds.decode.SkipDecoding()}
+    transformations = [
+        CentralCropAndResize(resize_size=image_size),
+        ReshapeAndConvertDtype(resize_size=image_size, dtype=dtype),
+        DropFeatures(('file_name',)),
+        transforms.CacheTransform()
+    ]
+
+  if preprocess_transform is not None:
+    transformations.append(preprocess_transform)
+
+  if data_dir is None:
+    data_dir = tfds.core.constants.ARRAY_RECORD_DATA_DIR
+
+  loader = grain.load_from_tfds(
+      name='imagenet2012',
+      data_dir=data_dir,
+      split=split,
+      shuffle=True,
+      seed=rng,
+      shard_options=grain.ShardByJaxProcess(drop_remainder=(split == 'train')),
+      decoders=decoders,
+      transformations=transformations,
+      num_epochs=None if is_train else 1,
+      batch_size=batch_size)
+
+  return loader.as_dataset(start_index=start_index)
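
Reviewer note (not part of the patch): a minimal smoke-test sketch for exercising the new pipeline end to end. It assumes the imagenet2012 ArrayRecord data is available in the default TFDS location, that data_utils.Dataset exposes its four constructor arguments as like-named fields (as elsewhere in init2winit), and that trainer-supplied hparams such as batch_size and model_dtype are not needed for this path; the batch sizes below are illustrative only.

    import jax
    from init2winit.dataset_lib import mlperf_imagenet_dataset_grain

    hps = mlperf_imagenet_dataset_grain.DEFAULT_HPARAMS
    rng = jax.random.PRNGKey(0)

    # Global batch sizes must be divisible by jax.device_count().
    dataset = mlperf_imagenet_dataset_grain.get_mlperf_imagenet(
        rng,
        batch_size=8 * jax.device_count(),
        eval_batch_size=8 * jax.device_count(),
        hps=hps)

    # Pull one training batch: a dict with 'inputs' of shape
    # [host_batch_size, 224, 224, 3] and one-hot 'targets' of shape
    # [host_batch_size, 1000].
    batch = next(dataset.train_iterator_fn())
    print(jax.tree_util.tree_map(lambda x: x.shape, batch))

    # One pass over the validation split; partial final batches come back
    # padded via data_utils.maybe_pad_batch.
    for eval_batch in dataset.valid_epoch():
      pass

Note that eval_train_epoch yields nothing unless num_batches is passed explicitly, matching the num_batches = 0 default above.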