Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Robustness and Error Handling in ImageFolder Dataset Builder #5567

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
365 changes: 184 additions & 181 deletions tensorflow_datasets/core/folder_dataset/image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,204 +40,207 @@


class ImageFolder(dataset_builder.DatasetBuilder):
  """Generic image classification dataset created from manual directory.

  `ImageFolder` creates a `tf.data.Dataset` reading the original image files.

  The data directory should have the following structure:

  ```
  path/to/image_dir/
    split_name/  # Ex: 'train'
      label1/  # Ex: 'airplane' or '0015'
        xxx.png
        xxy.png
        xxz.png
      label2/
        xxx.png
        xxy.png
        xxz.png
    split_name/  # Ex: 'test'
      ...
  ```

  To use it:

  ```
  builder = tfds.ImageFolder('path/to/image_dir/')
  print(builder.info)  # num examples, labels... are automatically calculated
  ds = builder.as_dataset(split='train', shuffle_files=True)
  tfds.show_examples(ds, builder.info)
  ```
  """

  VERSION = version.Version('1.0.0')

  def __init__(
      self,
      root_dir: str,
      *,
      shape: Optional[type_utils.Shape] = None,
      dtype: Optional[tf.DType] = None,
  ):
    """Construct the `DatasetBuilder`.

    Args:
      root_dir: Path to the directory containing the images.
      shape: Image shape forwarded to `tfds.features.Image`.
      dtype: Image dtype forwarded to `tfds.features.Image`.
    """
    # Shape/dtype must be set before `super().__init__` because `_info()`
    # reads them to build the `Image` feature.
    self._image_shape = shape
    self._image_dtype = dtype
    super(ImageFolder, self).__init__()
    self._data_dir = root_dir  # Set data_dir to the existing dir.

    # Extract the splits, examples and labels by scanning the directory tree.
    root_dir = os.path.expanduser(root_dir)
    self._split_examples, labels = _get_split_label_images(root_dir)

    # Update DatasetInfo labels.
    self.info.features['label'].names = sorted(labels)

    # Update DatasetInfo splits (one shard per split; num_bytes unknown).
    split_infos = [
        split_lib.SplitInfo(  # pylint: disable=g-complex-comprehension
            name=split_name,
            shard_lengths=[len(examples)],
            num_bytes=0,
        )
        for split_name, examples in self._split_examples.items()
    ]
    split_dict = split_lib.SplitDict(split_infos)
    self.info.set_splits(split_dict)

  def _info(self) -> dataset_info.DatasetInfo:
    """Returns the `DatasetInfo` (features, supervised keys) for this builder."""
    return dataset_info.DatasetInfo(
        builder=self,
        description='Generic image classification dataset.',
        features=features_lib.FeaturesDict({
            'image': features_lib.Image(
                shape=self._image_shape,
                dtype=self._image_dtype,
            ),
            'label': features_lib.ClassLabel(),
            'image/filename': features_lib.Text(),
        }),
        supervised_keys=('image', 'label'),
    )

  def _download_and_prepare(self, **kwargs) -> NoReturn:  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    """Not supported: the data is read directly from the manual directory."""
    raise NotImplementedError(
        'No need to call download_and_prepare function for {}.'.format(
            type(self).__name__
        )
    )

  def download_and_prepare(self, **kwargs):  # -> NoReturn:
    return self._download_and_prepare()

  def _as_dataset(
      self,
      split: str,
      shuffle_files: bool = False,
      decoders: Optional[Dict[str, decode.Decoder]] = None,
      read_config=None,
  ) -> tf.data.Dataset:
    """Generate dataset for given split.

    Args:
      split: Name of one of the scanned split directories (no subsplit API).
      shuffle_files: If True, examples are shuffled (buffer covers the split).
      decoders: Optional per-feature decoders forwarded to `decode_example`.
      read_config: Unused (automatically created in `DatasetBuilder`).

    Returns:
      A `tf.data.Dataset` of `{'image', 'label', 'image/filename'}` dicts.

    Raises:
      ValueError: If `split` is not one of the known split names.
    """
    del read_config  # Unused (automatically created in `DatasetBuilder`)

    if split not in self.info.splits.keys():
      raise ValueError(
          'Unrecognized split {}. Subsplit API not yet supported for {}. '
          'Split name should be one of {}.'.format(
              split, type(self).__name__, list(self.info.splits.keys())
          )
      )

    # Extract all labels/images.
    image_paths = []
    labels = []
    examples = self._split_examples[split]
    for example in examples:
      image_paths.append(example.image_path)
      labels.append(self.info.features['label'].str2int(example.label))

    # Build the tf.data.Dataset object.
    ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    if shuffle_files:
      ds = ds.shuffle(len(examples))

    # Fuse load and decode into one function.
    def _load_and_decode_fn(*args, **kwargs):
      ex = _load_example(*args, **kwargs)
      return self.info.features.decode_example(ex, decoders=decoders)

    ds = ds.map(
        _load_and_decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    return ds


def _load_example(
    path: tf.Tensor,
    label: tf.Tensor,
) -> Dict[str, tf.Tensor]:
  """Read the image file at `path` and pack it with its label and filename.

  Args:
    path: Scalar string tensor, path of the image file (read raw, not decoded).
    label: Scalar integer tensor, the label index.

  Returns:
    A feature dict with the encoded image bytes, the label as int64, and the
    original file path under 'image/filename'.
  """
  img = tf.io.read_file(path)
  return {
      'image': img,
      'label': tf.cast(label, tf.int64),
      'image/filename': path,
  }


def _get_split_label_images(
    root_dir: str,
) -> Tuple[SplitExampleDict, List[str]]:
  """Extract all label names and associated images.

  This function guarantees that examples are deterministically shuffled
  and labels are sorted.

  Args:
    root_dir: The folder where the `split/label/image.png` are located.

  Returns:
    split_examples: Mapping split_names -> List[_Example]
    labels: The list of labels.

  Raises:
    ValueError: If `root_dir` does not exist.
  """
  if not tf.io.gfile.exists(root_dir):
    raise ValueError(f"The provided root directory '{root_dir}' does not exist.")

  split_examples = collections.defaultdict(list)
  labels = set()
  for split_name in sorted(_list_folders(root_dir)):
    split_dir = os.path.join(root_dir, split_name)
    for label_name in sorted(_list_folders(split_dir)):
      labels.add(label_name)
      split_examples[split_name].extend(
          [
              _Example(image_path=image_path, label=label_name)
              for image_path in sorted(
                  _list_img_paths(os.path.join(split_dir, label_name))
              )
          ]
      )

  # Shuffle the images deterministically.
  for split_name, examples in split_examples.items():
    # Seed with the split name string itself: `random.Random(str)` hashes the
    # string with SHA-512, which is stable across processes. Do NOT seed with
    # `hash(split_name)` — str hashing is salted per process (PYTHONHASHSEED),
    # which would silently break the determinism guarantee documented above.
    rgn = random.Random(split_name)  # Uses different seed for each split
    rgn.shuffle(examples)
  return split_examples, sorted(labels)


def _list_folders(root_dir: str) -> List[str]:
  """Return the names of the immediate sub-directories of `root_dir`."""
  return [
      f
      for f in tf.io.gfile.listdir(root_dir)
      if tf.io.gfile.isdir(os.path.join(root_dir, f))
  ]


def _list_img_paths(root_dir: str) -> List[str]:
  """Return full paths of entries in `root_dir` with a supported image extension.

  Matching is case-insensitive against `_SUPPORTED_IMAGE_FORMAT`.
  """
  return [
      os.path.join(root_dir, f)
      for f in tf.io.gfile.listdir(root_dir)
      if any(f.lower().endswith(ext) for ext in _SUPPORTED_IMAGE_FORMAT)
  ]
Loading