From e05bcf5fd21e6a44d462e8b8a3d2b3a1825b0862 Mon Sep 17 00:00:00 2001
From: ugeunpark
Date: Tue, 22 Oct 2024 20:46:32 +0900
Subject: [PATCH 1/3] Add GLU activation function

---
 .../_tf_keras/keras/activations/__init__.py  |  1 +
 keras/api/_tf_keras/keras/ops/__init__.py    |  1 +
 keras/api/_tf_keras/keras/ops/nn/__init__.py |  1 +
 keras/api/activations/__init__.py            |  1 +
 keras/api/ops/__init__.py                    |  1 +
 keras/api/ops/nn/__init__.py                 |  1 +
 keras/src/activations/activations.py         | 21 ++++
 keras/src/backend/jax/nn.py                  |  5 +++
 keras/src/backend/numpy/nn.py                |  7 ++++
 keras/src/backend/tensorflow/nn.py           |  6 +++
 keras/src/backend/torch/nn.py                |  5 +++
 keras/src/ops/nn.py                          | 41 +++++++++++++++++++
 12 files changed, 91 insertions(+)

diff --git a/keras/api/_tf_keras/keras/activations/__init__.py b/keras/api/_tf_keras/keras/activations/__init__.py
index a56def1a208..5d562abd630 100644
--- a/keras/api/_tf_keras/keras/activations/__init__.py
+++ b/keras/api/_tf_keras/keras/activations/__init__.py
@@ -11,6 +11,7 @@
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
+from keras.src.activations.activations import glu
 from keras.src.activations.activations import hard_sigmoid
 from keras.src.activations.activations import hard_silu
 from keras.src.activations.activations import hard_silu as hard_swish
diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py
index 12a8571fd7d..29759d45a39 100644
--- a/keras/api/_tf_keras/keras/ops/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/__init__.py
@@ -71,6 +71,7 @@
 from keras.src.ops.nn import dot_product_attention
 from keras.src.ops.nn import elu
 from keras.src.ops.nn import gelu
+from keras.src.ops.nn import glu
 from keras.src.ops.nn import hard_sigmoid
 from keras.src.ops.nn import hard_silu
 from keras.src.ops.nn import hard_silu as hard_swish
diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py
index 49683dc70bd..faa269f4c3d 100644
--- a/keras/api/_tf_keras/keras/ops/nn/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.ops.nn import dot_product_attention
 from keras.src.ops.nn import elu
 from keras.src.ops.nn import gelu
+from keras.src.ops.nn import glu
 from keras.src.ops.nn import hard_sigmoid
 from keras.src.ops.nn import hard_silu
 from keras.src.ops.nn import hard_silu as hard_swish
diff --git a/keras/api/activations/__init__.py b/keras/api/activations/__init__.py
index a56def1a208..5d562abd630 100644
--- a/keras/api/activations/__init__.py
+++ b/keras/api/activations/__init__.py
@@ -11,6 +11,7 @@
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
+from keras.src.activations.activations import glu
 from keras.src.activations.activations import hard_sigmoid
 from keras.src.activations.activations import hard_silu
 from keras.src.activations.activations import hard_silu as hard_swish
diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py
index 12a8571fd7d..29759d45a39 100644
--- a/keras/api/ops/__init__.py
+++ b/keras/api/ops/__init__.py
@@ -71,6 +71,7 @@
 from keras.src.ops.nn import dot_product_attention
 from keras.src.ops.nn import elu
 from keras.src.ops.nn import gelu
+from keras.src.ops.nn import glu
 from keras.src.ops.nn import hard_sigmoid
 from keras.src.ops.nn import hard_silu
 from keras.src.ops.nn import hard_silu as hard_swish
diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py
index 49683dc70bd..faa269f4c3d 100644
--- a/keras/api/ops/nn/__init__.py
+++ b/keras/api/ops/nn/__init__.py
@@ -17,6 +17,7 @@
 from keras.src.ops.nn import dot_product_attention
 from keras.src.ops.nn import elu
 from keras.src.ops.nn import gelu
+from keras.src.ops.nn import glu
 from keras.src.ops.nn import hard_sigmoid
 from keras.src.ops.nn import hard_silu
 from keras.src.ops.nn import hard_silu as hard_swish
diff --git a/keras/src/activations/activations.py b/keras/src/activations/activations.py
index 8dc56ee43b7..78ae14ba473 100644
--- a/keras/src/activations/activations.py
+++ b/keras/src/activations/activations.py
@@ -323,6 +323,27 @@ def celu(x, alpha=1.0):
     return ops.celu(x, alpha=alpha)
 
 
+@keras_export("keras.activations.glu")
+def glu(x, axis=-1):
+    """Gated Linear Unit (GLU) activation function.
+
+    The GLU activation function is defined as:
+
+    `glu(x) = a * sigmoid(b)`,
+
+    where `x` is split into two equal parts `a` and `b` along the given axis.
+
+    Args:
+        x: Input tensor.
+        axis: The axis along which to split the input tensor. Defaults to `-1`.
+
+    Reference:
+
+    - [Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)
+    """
+    return ops.glu(x, axis=axis)
+
+
 @keras_export("keras.activations.tanh")
 def tanh(x):
     """Hyperbolic tangent activation function.
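The new `keras.activations.glu` is a thin wrapper around `ops.glu`: the input is split into two equal halves `a` and `b` along `axis`, and the result is `a * sigmoid(b)`. A minimal usage sketch, assuming a Keras 3 environment with this patch applied (the input values are purely illustrative):

    import numpy as np
    from keras import activations

    # The last axis has size 4, so it is split into a = x[..., :2] and b = x[..., 2:].
    x = np.array([[1.0, 2.0, 0.0, -1.0]], dtype="float32")
    y = activations.glu(x)  # a * sigmoid(b), shape (1, 2)
    print(np.asarray(y))    # approx. [[0.5, 0.5378828]]
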
diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index 9899f1b65d5..0ae1e2b42a6 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -90,6 +90,11 @@ def celu(x, alpha=1.0):
     return jnn.celu(x, alpha=alpha)
 
 
+def glu(x, axis=-1):
+    x = convert_to_tensor(x)
+    return jnn.glu(x, axis=axis)
+
+
 def softmax(x, axis=-1):
     x = convert_to_tensor(x)
     return jnn.softmax(x, axis=axis)
diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py
index 6e3f8203957..9f5e4674274 100644
--- a/keras/src/backend/numpy/nn.py
+++ b/keras/src/backend/numpy/nn.py
@@ -121,6 +121,13 @@ def celu(x, alpha=1.0):
     )
 
 
+def glu(x, axis=-1):
+    x = convert_to_tensor(x)
+    assert x.shape[axis] % 2 == 0, "axis size must be divisible by 2"
+    x1, x2 = np.split(x, 2, axis)
+    return x1 * (1 / (1 + np.exp(-x2)))
+
+
 def softmax(x, axis=None):
     exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
     return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py
index 7c16e5e901b..70eb89ffcbe 100644
--- a/keras/src/backend/tensorflow/nn.py
+++ b/keras/src/backend/tensorflow/nn.py
@@ -82,6 +82,12 @@ def celu(x, alpha=1.0):
     )
 
 
+def glu(x, axis=-1):
+    assert x.shape[axis] % 2 == 0, "axis size must be divisible by 2"
+    x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis)
+    return x1 * tf.sigmoid(x2)
+
+
 def softmax(x, axis=-1):
     logits = x
     if axis is None:
diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py
index 7c253988480..cd56bad7c53 100644
--- a/keras/src/backend/torch/nn.py
+++ b/keras/src/backend/torch/nn.py
@@ -93,6 +93,11 @@ def celu(x, alpha=1.0):
     return tnn.celu(x, alpha=alpha)
 
 
+def glu(x, axis=-1):
+    x = convert_to_tensor(x)
+    return tnn.glu(x, dim=axis)
+
+
 def softmax(x, axis=-1):
     x = convert_to_tensor(x)
     dtype = backend.standardize_dtype(x.dtype)
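All four backends compute the same thing: split the tensor into two equal halves along `axis` and gate the first half with the sigmoid of the second. JAX and torch defer to `jax.nn.glu` and `torch.nn.functional.glu`, while the NumPy and TensorFlow versions spell the computation out, which is why only those two carry the explicit even-size check. A self-contained NumPy sketch of that reference computation (the helper name `glu_reference` is ours, not part of the patch), using the same values as the numeric test added later in this series:

    import numpy as np

    def glu_reference(x, axis=-1):
        # Split into two equal halves along `axis`, then gate: a * sigmoid(b).
        a, b = np.split(x, 2, axis=axis)
        return a * (1.0 / (1.0 + np.exp(-b)))

    x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype="float32")
    print(glu_reference(x))  # approx. [-0.8807971, 0.0, 0.98201376]
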
diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py
index 0531f87f869..d9681d6a0cb 100644
--- a/keras/src/ops/nn.py
+++ b/keras/src/ops/nn.py
@@ -538,6 +538,47 @@ def celu(x, alpha=1.0):
     return backend.nn.celu(x, alpha)
 
 
+class Glu(Operation):
+    def __init__(self, axis=-1):
+        super().__init__()
+        self.axis = axis
+
+    def call(self, x):
+        return backend.nn.glu(x, axis=self.axis)
+
+    def compute_output_spec(self, x):
+        return KerasTensor(x.shape, dtype=x.dtype)
+
+
+@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
+def glu(x, axis=-1):
+    """Gated Linear Unit (GLU) activation function.
+
+    It is defined as:
+
+    `f(x) = a * sigmoid(b)`
+    where `x` is split into `a` and `b` along the given axis.
+
+    Args:
+        x: Input tensor.
+        axis: The axis along which to split the input tensor. Defaults to `-1`.
+
+    Returns:
+        A tensor with the same shape as half of the input.
+
+    Example:
+
+    >>> x = np.array([-1., 0., 1. , 1.])
+    >>> x_glu = keras.ops.glu(x)
+    >>> print(x_glu)
+    array([-0.73105858, 0. ], shape=(2,), dtype=float64)
+
+    """
+    if any_symbolic_tensors((x,)):
+        return Glu(axis).symbolic_call(x)
+    return backend.nn.glu(x, axis=axis)
+
+
 class Softmax(Operation):
     def __init__(self, axis=-1):
         super().__init__()
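That completes the first commit: `keras.ops.glu` (also reachable as `keras.ops.nn.glu`) dispatches to the backend implementations above for eager tensors and to the `Glu` operation for symbolic ones. A short eager-mode sketch, assuming any of the supported backends (values illustrative):

    import numpy as np
    from keras import ops

    x = np.arange(6, dtype="float32").reshape(2, 3)  # shape (2, 3)
    # Splitting along axis 0 (size 2) gives a = x[0:1], b = x[1:2].
    y = ops.glu(x, axis=0)                           # shape (1, 3)
    print(np.asarray(y))  # approx. [[0., 0.9820138, 1.9866143]]

Note that the eager result halves the split axis, while `Glu.compute_output_spec` reports the input shape unchanged for symbolic tensors; the shape tests in the next commit check exactly that symbolic behavior.
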
From 967f4ad96ab6b6211a38269b53aeb62b4d38109f Mon Sep 17 00:00:00 2001
From: ugeunpark
Date: Wed, 23 Oct 2024 20:04:59 +0900
Subject: [PATCH 2/3] Add test cases for GLU

---
 keras/src/activations/__init__.py         |  2 ++
 keras/src/activations/activations_test.py | 15 +++++++++++
 keras/src/ops/nn_test.py                  | 33 +++++++++++++++++++++++
 3 files changed, 50 insertions(+)

diff --git a/keras/src/activations/__init__.py b/keras/src/activations/__init__.py
index 57cd085a173..4266512e2a8 100644
--- a/keras/src/activations/__init__.py
+++ b/keras/src/activations/__init__.py
@@ -4,6 +4,7 @@
 from keras.src.activations.activations import elu
 from keras.src.activations.activations import exponential
 from keras.src.activations.activations import gelu
+from keras.src.activations.activations import glu
 from keras.src.activations.activations import hard_sigmoid
 from keras.src.activations.activations import hard_silu
 from keras.src.activations.activations import leaky_relu
@@ -35,6 +36,7 @@
     softsign,
     silu,
     gelu,
+    glu,
     tanh,
     sigmoid,
     exponential,
diff --git a/keras/src/activations/activations_test.py b/keras/src/activations/activations_test.py
index 045ffab14d8..00d8167239c 100644
--- a/keras/src/activations/activations_test.py
+++ b/keras/src/activations/activations_test.py
@@ -598,6 +598,21 @@ def celu(x, alpha=1.0):
         expected = celu(x, True)
         self.assertAllClose(result, expected, rtol=1e-05)
 
+    def test_glu(self):
+        def glu(x, axis=-1):
+            x1, x2 = np.split(x, 2, axis)
+            return x1 * (1 / (1 + np.exp(-x2)))
+
+        x = np.random.random((2, 4))
+        result = activations.glu(x[np.newaxis, :])[0]
+        expected = glu(x)
+        self.assertAllClose(result, expected, rtol=1e-05)
+
+        x = np.random.random((2, 4))
+        result = activations.glu(x[np.newaxis, :], axis=-2)[0]
+        expected = glu(x, axis=-2)
+        self.assertAllClose(result, expected, rtol=1e-05)
+
     def test_elu(self):
         x = np.random.random((2, 5))
         result = activations.elu(x[np.newaxis, :])[0]
diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py
index a14a32b46f0..8335947c6b8 100644
--- a/keras/src/ops/nn_test.py
+++ b/keras/src/ops/nn_test.py
@@ -145,6 +145,10 @@ def test_celu(self):
         x = KerasTensor([None, 2, 3])
         self.assertEqual(knn.celu(x).shape, (None, 2, 3))
 
+    def test_glu(self):
+        x = KerasTensor([None, 2, 3])
+        self.assertEqual(knn.glu(x).shape, (None, 2, 3))
+
     def test_softmax(self):
         x = KerasTensor([None, 2, 3])
         self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -794,6 +798,10 @@ def test_celu(self):
         x = KerasTensor([1, 2, 3])
         self.assertEqual(knn.celu(x).shape, (1, 2, 3))
 
+    def test_glu(self):
+        x = KerasTensor([1, 2, 3])
+        self.assertEqual(knn.glu(x).shape, (1, 2, 3))
+
     def test_softmax(self):
         x = KerasTensor([1, 2, 3])
         self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1307,6 +1315,13 @@ def test_celu(self):
             [-0.63212055, 0.0, 1.0, 2.0, 3.0],
         )
 
+    def test_glu(self):
+        x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32)
+        self.assertAllClose(
+            knn.glu(x),
+            [-0.8807971, 0.0, 0.98201376],
+        )
+
     def test_softmax(self):
         x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
         self.assertAllClose(
@@ -2396,6 +2411,24 @@ def test_celu(self, dtype):
             expected_dtype,
         )
 
+    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
+    def test_glu(self, dtype):
+        import jax.nn as jnn
+        import jax.numpy as jnp
+
+        x = knp.ones((2), dtype=dtype)
+        x_jax = jnp.ones((2), dtype=dtype)
+        expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)
+
+        self.assertEqual(
+            standardize_dtype(knn.glu(x).dtype),
+            expected_dtype,
+        )
+        self.assertEqual(
+            standardize_dtype(knn.Glu().symbolic_call(x).dtype),
+            expected_dtype,
+        )
+
     @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
     def test_hard_sigmoid(self, dtype):
         import jax.nn as jnn

From 1e4fc2ddddb284f8cfc5cdc50b058c16dc83ac5a Mon Sep 17 00:00:00 2001
From: ugeunpark
Date: Thu, 24 Oct 2024 19:19:22 +0900
Subject: [PATCH 3/3] Update assert statement to ValueError

---
 keras/src/backend/numpy/nn.py      | 6 +++++-
 keras/src/backend/tensorflow/nn.py | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py
index 9f5e4674274..b1f30f20b84 100644
--- a/keras/src/backend/numpy/nn.py
+++ b/keras/src/backend/numpy/nn.py
@@ -123,7 +123,11 @@ def celu(x, alpha=1.0):
 
 def glu(x, axis=-1):
     x = convert_to_tensor(x)
-    assert x.shape[axis] % 2 == 0, "axis size must be divisible by 2"
+    if x.shape[axis] % 2 != 0:
+        raise ValueError(
+            "axis size must be divisible by 2. "
+            f"Received: x.shape={x.shape} with axis={axis}"
+        )
     x1, x2 = np.split(x, 2, axis)
     return x1 * (1 / (1 + np.exp(-x2)))
 
diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py
index 70eb89ffcbe..ad3584ee4c2 100644
--- a/keras/src/backend/tensorflow/nn.py
+++ b/keras/src/backend/tensorflow/nn.py
@@ -83,7 +83,11 @@ def celu(x, alpha=1.0):
 
 
 def glu(x, axis=-1):
-    assert x.shape[axis] % 2 == 0, "axis size must be divisible by 2"
+    if x.shape[axis] % 2 != 0:
+        raise ValueError(
+            "axis size must be divisible by 2. "
+            f"Received: x.shape={x.shape} with axis={axis}"
+        )
     x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis)
     return x1 * tf.sigmoid(x2)
 
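With the validation switched from `assert` to `ValueError`, an odd-sized split axis now fails with a clear message even under `python -O`, where asserts are stripped. A small sketch of the expected behavior on the NumPy or TensorFlow backend (the error text is the one introduced above; values illustrative):

    import numpy as np
    from keras import ops

    x = np.ones((3,), dtype="float32")  # odd size along the last axis
    try:
        ops.glu(x)
    except ValueError as e:
        # e.g. "axis size must be divisible by 2. Received: x.shape=(3,) with axis=-1"
        print(e)
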