From b3d0eaad4529e971b8e640458428fbaa59c0ff1f Mon Sep 17 00:00:00 2001 From: Adithya Kamath Date: Wed, 19 Jul 2023 19:38:17 +0530 Subject: [PATCH] Random rotations pt-i --- .../layers/preprocessing/random_rotation.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/keras_core/layers/preprocessing/random_rotation.py b/keras_core/layers/preprocessing/random_rotation.py index 6cf08e439..6fd9ab475 100644 --- a/keras_core/layers/preprocessing/random_rotation.py +++ b/keras_core/layers/preprocessing/random_rotation.py @@ -3,6 +3,7 @@ from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer +from keras_core.random.seed_generator import SeedGenerator from keras_core.utils import backend_utils from keras_core.utils.module_utils import tensorflow as tf @@ -84,16 +85,11 @@ def __init__( seed=None, fill_value=0.0, name=None, + data_format=None, **kwargs, ): - if not tf.available: - raise ImportError( - "Layer RandomRotation requires TensorFlow. " - "Install it via `pip install tensorflow`." - ) - super().__init__(name=name, **kwargs) - self.seed = seed or backend.random.make_default_seed() + self.seed = SeedGenerator(seed) self.layer = tf.keras.layers.RandomRotation( factor=factor, fill_mode=fill_mode, @@ -103,6 +99,7 @@ def __init__( name=name, **kwargs, ) + self.data_format = backend.standardize_data_format(data_format) self.supports_jit = False self._convert_input_args = False self._allow_non_tensor_positional_args = True @@ -122,6 +119,11 @@ def compute_output_shape(self, input_shape): return tuple(self.layer.compute_output_shape(input_shape)) def get_config(self): - config = self.layer.get_config() - config.update({"seed": self.seed}) + config = super().get_config() + config.update( + { + "seed": self.seed, + "data_format": self.data_format, + } + ) return config