Skip to content

Commit

Permalink
Split Transform normalization into separate files to isolate dependencies on tfgrain.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 677830924
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Sep 23, 2024
1 parent 14744bd commit c3f62c1
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 145 deletions.
40 changes: 5 additions & 35 deletions kauldron/data/py/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
import dataclasses
import functools
import typing
Expand All @@ -28,11 +27,10 @@
from kauldron import random
from kauldron.data import iterators
from kauldron.data import pipelines
from kauldron.data.py import transform_utils
from kauldron.data.transforms import normalize as tr_normalize
from kauldron.typing import PRNGKeyLike # pylint: disable=g-importing-member,g-multiple-import

_Transforms = Sequence[grain.Transformation] | dict[str, grain.Transformation]


@dataclasses.dataclass(frozen=True, kw_only=True, eq=True)
class PyGrainPipeline(pipelines.Pipeline):
Expand Down Expand Up @@ -60,7 +58,9 @@ class PyGrainPipeline(pipelines.Pipeline):
batch_size: int | None = ...
seed: PRNGKeyLike | None = ...

transforms: _Transforms = dataclasses.field(default_factory=tuple)
transforms: tr_normalize.Transformations = dataclasses.field(
default_factory=tuple
)

# Params only relevant for the root top-level dataset (when dataset mixture)
num_epochs: Optional[int] = None
Expand All @@ -82,7 +82,7 @@ def ds_with_transforms(self, rng: random.PRNGKey) -> grain.MapDataset:
"""Create the `tf.data.Dataset` and apply all the transforms."""
ds = self.ds_for_current_process(rng)

ds = _apply_transforms(ds, self.transforms)
ds = transform_utils.apply_transforms(ds, self.transforms)

if self.batch_size:
ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)
Expand Down Expand Up @@ -179,33 +179,3 @@ def _get_num_workers(num_workers: int) -> int:
return 0
else:
return num_workers


def _apply_transforms(
    ds: grain.MapDataset, transforms: _Transforms
) -> grain.MapDataset:
  """Run every transformation over `ds`, in iteration order.

  Args:
    ds: Dataset to transform.
    transforms: Sequence of transformations, or a mapping whose values are the
      transformations (the keys are not used).

  Returns:
    The dataset obtained after applying each transformation in turn.
  """
  pending = (
      transforms.values() if isinstance(transforms, Mapping) else transforms
  )
  for raw in pending:
    # Each entry is first adapted for PyGrain, then dispatched by its type.
    ds = _apply_transform(ds, tr_normalize.adapt_for_pygrain(raw))
  return ds


def _apply_transform(
    ds: grain.MapDataset, tr: grain.Transformation
) -> grain.MapDataset:
  """Apply one transformation to the dataset.

  Args:
    ds: Dataset to transform.
    tr: A single grain transformation; its type selects the dataset method.

  Returns:
    The dataset with `tr` applied.

  Raises:
    ValueError: If `tr` is none of the handled transformation types.
  """
  # The checks run in the same order as the original `case` clauses.
  if isinstance(tr, grain.MapTransform):
    return ds.map(tr)
  if isinstance(tr, grain.RandomMapTransform):
    return ds.random_map(tr)
  if isinstance(tr, grain.FilterTransform):
    return ds.filter(tr)
  if isinstance(tr, grain.Batch):
    return ds.batch(tr.batch_size, drop_remainder=tr.drop_remainder)
  raise ValueError(f"Unexpected transform type: {tr}")
89 changes: 89 additions & 0 deletions kauldron/data/py/transform_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for using Kauldron transforms with PyGrain."""

from typing import Any, Callable, Mapping

import grain.python as grain
from kauldron.data.transforms import abc as tr_abc
from kauldron.data.transforms import normalize as tr_normalize


class PyGrainMapAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
  """Adapter exposing a `kd.data.MapTransform` as a pygrain `MapTransform`."""

  def map(self, element: Any) -> Any:
    """Forward `element` to the wrapped transform's `map`."""
    wrapped = self.transform
    return wrapped.map(element)


class PyGrainFilterAdapter(
    tr_normalize.TransformAdapter, grain.FilterTransform
):
  """Adapter exposing a `kd.data.FilterTransform` as a pygrain filter."""

  def filter(self, element: Any) -> bool:
    """Forward `element` to the wrapped transform's `filter`."""
    wrapped = self.transform
    return wrapped.filter(element)


class PyGrainCallableAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
  """Adapter exposing an arbitrary callable as a pygrain `MapTransform`."""

  def map(self, element: Any) -> Any:
    """Call the wrapped callable on `element` and return its result."""
    fn = self.transform
    return fn(element)


# Kauldron transform type -> PyGrain adapter class wrapping it. Handed to
# `tr_normalize.adapt_transform` by `_adapt_for_pygrain`.
# NOTE(review): entry order is presumably significant (generic `Callable`
# fallback last) — confirm against `adapt_transform`.
_KD_TO_PYGRAIN_ADAPTERS = {
    tr_abc.MapTransform: PyGrainMapAdapter,
    tr_abc.FilterTransform: PyGrainFilterAdapter,
    Callable: PyGrainCallableAdapter,
}


def _adapt_for_pygrain(
    transform: tr_normalize.Transformation,
) -> grain.Transformation:
  """Return `transform` in a form PyGrain accepts.

  Args:
    transform: Either a grain transformation, or something that
      `tr_normalize.adapt_transform` can wrap using the PyGrain adapters.

  Returns:
    `transform` itself when it already is a grain transformation, otherwise
    the result of `tr_normalize.adapt_transform`.
  """
  return (
      transform
      if isinstance(transform, grain.Transformation)
      else tr_normalize.adapt_transform(transform, _KD_TO_PYGRAIN_ADAPTERS)
  )


def apply_transforms(
    ds: grain.MapDataset, transforms: tr_normalize.Transformations
) -> grain.MapDataset:
  """Run every transformation over `ds`, in iteration order.

  Args:
    ds: Dataset to transform.
    transforms: Sequence of transformations, or a mapping whose values are the
      transformations (the keys are not used).

  Returns:
    The dataset obtained after applying each transformation in turn.
  """
  steps = (
      transforms.values() if isinstance(transforms, Mapping) else transforms
  )
  for step in steps:
    # Each entry is first adapted for PyGrain, then dispatched by its type.
    ds = _apply_transform(ds, _adapt_for_pygrain(step))
  return ds


def _apply_transform(
    ds: grain.MapDataset, tr: grain.Transformation
) -> grain.MapDataset:
  """Apply one transformation to the dataset.

  Args:
    ds: Dataset to transform.
    tr: A single grain transformation; its type selects the dataset method.

  Returns:
    The dataset with `tr` applied.

  Raises:
    ValueError: If `tr` is none of the handled transformation types.
  """
  # The checks run in the same order as the original `case` clauses.
  if isinstance(tr, grain.MapTransform):
    return ds.map(tr)
  if isinstance(tr, grain.RandomMapTransform):
    return ds.random_map(tr)
  if isinstance(tr, grain.FilterTransform):
    return ds.filter(tr)
  if isinstance(tr, grain.Batch):
    return ds.batch(tr.batch_size, drop_remainder=tr.drop_remainder)
  raise ValueError(f"Unexpected transform type: {tr}")
79 changes: 72 additions & 7 deletions kauldron/data/tf/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transform utils."""
"""Utils for using Kauldron transforms with tfgrain."""

import functools
from typing import Any
from typing import Any, Callable, Mapping

from grain._src.tensorflow import transforms as grain_transforms
import grain.tensorflow as grain
Expand All @@ -25,17 +25,82 @@
import tensorflow as tf


# Kauldron transforms for kd.data.tf supports both `kd.data.MapTransform` and
# `tfgrain.MapTransform`
Transformation = grain.Transformation | tr_abc.Transformation
class TfGrainMapAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
  """Adapter exposing a `kd.data.MapTransform` as a tfgrain `MapTransform`."""

  @property
  def name(self):
    """Name of the wrapped transform if it has one, else the inherited name."""
    # Used by tfgrain to name the operations in the tf graph.
    inherited = super().name
    return getattr(self.transform, 'name', inherited)

  @property
  def num_parallel_calls_hint(self):
    """Hint of the wrapped transform if it has one, else the inherited hint."""
    # Can be used to modify the default parallelization behavior of tfgrain.
    inherited = super().num_parallel_calls_hint
    return getattr(self.transform, 'num_parallel_calls_hint', inherited)

  def map(self, element: Any) -> Any:
    """Run the wrapped transform on the example features of `element`."""
    # Required due to b/326590491: only the example features are passed to the
    # wrapped transform; the meta features are merged back into its output.
    meta, example = grain_utils.split_grain_meta_features(element)
    transformed = self.transform.map(example)
    return grain_utils.merge_grain_meta_features(meta, transformed)


class TfGrainCallableAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
  """Adapter exposing an arbitrary callable as a tfgrain `MapTransform`."""

  def map(self, element: Any) -> Any:
    """Call the wrapped callable on the example features of `element`."""
    # Required due to b/326590491: only the example features are passed to the
    # callable; the meta features are merged back into its output.
    meta, example = grain_utils.split_grain_meta_features(element)
    transformed = self.transform(example)
    return grain_utils.merge_grain_meta_features(meta, transformed)


class TfGrainFilterAdapter(
    tr_normalize.TransformAdapter, grain.FilterTransform
):
  """Adapter exposing a `kd.data.FilterTransform` as a tfgrain filter."""

  @property
  def name(self):
    """Name of the wrapped transform if it has one, else the inherited name."""
    # Used by tfgrain to name the operations in the tf graph.
    inherited = super().name
    return getattr(self.transform, 'name', inherited)

  def filter(self, elements: Any) -> Any:
    """Evaluate the wrapped filter on the example features of `elements`."""
    # Required due to b/326590491: the meta features are split off and ignored.
    _meta, example = grain_utils.split_grain_meta_features(elements)
    return self.transform.filter(example)


# Kauldron transform type -> tfgrain adapter class wrapping it. Handed to
# `tr_normalize.adapt_transform` by `_adapt_for_tfgrain`.
# NOTE(review): entry order is presumably significant (generic `Callable`
# fallback last) — confirm against `adapt_transform`.
_KD_TO_TFGRAIN_ADAPTERS = {
    tr_abc.MapTransform: TfGrainMapAdapter,
    tr_abc.FilterTransform: TfGrainFilterAdapter,
    Callable: TfGrainCallableAdapter,  # support grand-vision preprocessing ops
}


def _adapt_for_tfgrain(
    transform: tr_normalize.Transformation,
) -> grain.Transformation:
  """Return `transform` in a form tfgrain accepts.

  Args:
    transform: Either a grain transformation, or something that
      `tr_normalize.adapt_transform` can wrap using the tfgrain adapters.

  Returns:
    `transform` itself when it already is a grain transformation, otherwise
    the result of `tr_normalize.adapt_transform`.
  """
  if not isinstance(transform, grain.Transformation):
    return tr_normalize.adapt_transform(transform, _KD_TO_TFGRAIN_ADAPTERS)
  return transform


def apply_transformations(
    ds: tf.data.Dataset,
    transforms: tr_normalize.Transformations,
) -> tf.data.Dataset:
  """Wrapper around grain to apply the transformations.

  Args:
    ds: Dataset to transform.
    transforms: Sequence of transformations, or a mapping whose values are the
      transformations (the keys are not used).

  Returns:
    The dataset returned by grain's `apply_transformations` (called with
    `strict=True`) for the adapted transformations.
  """
  # A mapping contributes only its values.
  if isinstance(transforms, Mapping):
    transforms = transforms.values()
  # Every entry is adapted for tfgrain before being handed to grain.
  transforms = [_adapt_for_tfgrain(tr) for tr in transforms]
  return grain_transforms.apply_transformations(ds, transforms, strict=True)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from grain import tensorflow as grain
from kauldron import kd
from kauldron.data.tf import grain_utils
from kauldron.data.transforms import normalize
from kauldron.data.tf import transform_utils
import pytest
import tensorflow as tf

Expand Down Expand Up @@ -66,7 +66,7 @@ def test_source(is_supervised: bool):
else {'image': None, 'label': None}
)
assert not isinstance(tr, grain.MapTransform)
tr = normalize.adapt_for_tfgrain(tr)
tr = transform_utils._adapt_for_tfgrain(tr)
assert isinstance(tr, grain.MapTransform)
data_loader = grain.TfDataLoader(
source=source,
Expand Down
Loading

0 comments on commit c3f62c1

Please sign in to comment.