Add GLU activation (#20392)
* Add GLU activation function

* Add test cases for GLU

* Update assert statement to ValueError
shashaka authored Oct 24, 2024
1 parent 4b771bb commit a1de472
Showing 15 changed files with 149 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -11,6 +11,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -71,6 +71,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -17,6 +17,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -11,6 +11,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -71,6 +71,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -17,6 +17,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -4,6 +4,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import leaky_relu
@@ -35,6 +36,7 @@
    softsign,
    silu,
    gelu,
    glu,
    tanh,
    sigmoid,
    exponential,
21 changes: 21 additions & 0 deletions keras/src/activations/activations.py
@@ -323,6 +323,27 @@ def celu(x, alpha=1.0):
    return ops.celu(x, alpha=alpha)


@keras_export("keras.activations.glu")
def glu(x, axis=-1):
    """Gated Linear Unit (GLU) activation function.

    The GLU activation function is defined as:

    `glu(x) = a * sigmoid(b)`,

    where `x` is split into two equal parts `a` and `b` along the given axis.

    Args:
        x: Input tensor.
        axis: The axis along which to split the input tensor. Defaults to `-1`.

    Reference:

    - [Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)
    """
    return ops.glu(x, axis=axis)
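
As a quick usage illustration of the new activation (the values and shapes below are hypothetical, not part of the diff; the size of the split axis must be even):

import numpy as np
from keras import activations

x = np.array([[1.0, 2.0, 3.0, 4.0]])  # last axis has size 4, so it splits into a = [1, 2] and b = [3, 4]
y = activations.glu(x)                # computes a * sigmoid(b)
# y is approximately [[0.9526, 1.9640]], since sigmoid(3) ~ 0.9526 and sigmoid(4) ~ 0.9820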


@keras_export("keras.activations.tanh")
def tanh(x):
    """Hyperbolic tangent activation function.
15 changes: 15 additions & 0 deletions keras/src/activations/activations_test.py
@@ -598,6 +598,21 @@ def celu(x, alpha=1.0):
        expected = celu(x, True)
        self.assertAllClose(result, expected, rtol=1e-05)

    def test_glu(self):
        def glu(x, axis=-1):
            x1, x2 = np.split(x, 2, axis)
            return x1 * (1 / (1 + np.exp(-x2)))

        x = np.random.random((2, 4))
        result = activations.glu(x[np.newaxis, :])[0]
        expected = glu(x)
        self.assertAllClose(result, expected, rtol=1e-05)

        x = np.random.random((2, 4))
        result = activations.glu(x[np.newaxis, :], axis=-2)[0]
        expected = glu(x, axis=-2)
        self.assertAllClose(result, expected, rtol=1e-05)

    def test_elu(self):
        x = np.random.random((2, 5))
        result = activations.elu(x[np.newaxis, :])[0]
5 changes: 5 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -90,6 +90,11 @@ def celu(x, alpha=1.0):
    return jnn.celu(x, alpha=alpha)


def glu(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.glu(x, axis=axis)
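
A shape-focused sketch of the underlying jax.nn call (standalone, not part of the diff): the split axis is halved in the output.

import jax.numpy as jnp
from jax import nn as jnn

x = jnp.arange(8.0).reshape(2, 4)   # shape (2, 4)
y = jnn.glu(x, axis=-1)             # a * sigmoid(b) with a = x[:, :2] and b = x[:, 2:]
print(y.shape)                      # (2, 2)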


def softmax(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.softmax(x, axis=axis)
11 changes: 11 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -121,6 +121,17 @@ def celu(x, alpha=1.0):
    )


def glu(x, axis=-1):
    x = convert_to_tensor(x)
    if x.shape[axis] % 2 != 0:
        raise ValueError(
            "axis size must be divisible by 2. "
            f"Received: x.shape={x.shape} with axis={axis}"
        )
    x1, x2 = np.split(x, 2, axis)
    return x1 * (1 / (1 + np.exp(-x2)))
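
For a hand-checkable sanity pass over the NumPy path above (illustrative values only):

import numpy as np

x = np.array([2.0, -1.0, 0.0, 3.0])
x1, x2 = np.split(x, 2, axis=-1)     # x1 = [2, -1], x2 = [0, 3]
print(x1 * (1 / (1 + np.exp(-x2))))  # roughly [1.0, -0.9526], since sigmoid(0) = 0.5 and sigmoid(3) ~ 0.9526
# an input whose split axis has odd length, e.g. np.ones(3), would trigger the ValueError above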


def softmax(x, axis=None):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
10 changes: 10 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -82,6 +82,16 @@ def celu(x, alpha=1.0):
    )


def glu(x, axis=-1):
    if x.shape[axis] % 2 != 0:
        raise ValueError(
            "axis size must be divisible by 2. "
            f"Received: x.shape={x.shape} with axis={axis}"
        )
    x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis)
    return x1 * tf.sigmoid(x2)
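
A minimal check of the guard above (a standalone sketch with illustrative values):

import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])  # last axis has size 3, which is odd
# calling glu(x) as defined above raises:
# ValueError: axis size must be divisible by 2. Received: x.shape=(1, 3) with axis=-1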


def softmax(x, axis=-1):
    logits = x
    if axis is None:
5 changes: 5 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -93,6 +93,11 @@ def celu(x, alpha=1.0):
    return tnn.celu(x, alpha=alpha)


def glu(x, axis=-1):
    x = convert_to_tensor(x)
    return tnn.glu(x, dim=axis)
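
For comparison, a standalone sketch of the underlying torch call (not part of the diff); torch's `dim` argument plays the role of `axis` here:

import torch
import torch.nn.functional as tnn

x = torch.tensor([-1.0, 0.0, 1.0, 1.0])
print(tnn.glu(x, dim=-1))  # roughly tensor([-0.7311, 0.0000]): first half times sigmoid of second half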


def softmax(x, axis=-1):
    x = convert_to_tensor(x)
    dtype = backend.standardize_dtype(x.dtype)
41 changes: 41 additions & 0 deletions keras/src/ops/nn.py
@@ -538,6 +538,47 @@ def celu(x, alpha=1.0):
    return backend.nn.celu(x, alpha)


class Glu(Operation):
    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def call(self, x):
        return backend.nn.glu(x, axis=self.axis)

    def compute_output_spec(self, x):
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
def glu(x, axis=-1):
    """Gated Linear Unit (GLU) activation function.

    It is defined as:

    `f(x) = a * sigmoid(b)`

    where `x` is split into `a` and `b` along the given axis.

    Args:
        x: Input tensor.
        axis: The axis along which to split the input tensor. Defaults to `-1`.

    Returns:
        A tensor with the same shape as the input, except that the size of
        the split axis is halved.

    Example:

    >>> x = np.array([-1., 0., 1. , 1.])
    >>> x_glu = keras.ops.glu(x)
    >>> print(x_glu)
    array([-0.73105858, 0. ], shape=(2,), dtype=float64)

    """
    if any_symbolic_tensors((x,)):
        return Glu(axis).symbolic_call(x)
    return backend.nn.glu(x, axis=axis)
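
And a small check of the `axis` argument on the public op (hypothetical values; the axis being split must have even length):

import numpy as np
from keras import ops

x = np.array([[-1.0, 0.0],
              [1.0, 1.0]])   # shape (2, 2)
print(ops.glu(x, axis=0))    # splits the rows: a = [-1, 0], b = [1, 1]
# result is roughly [[-0.7311, 0.0]], i.e. a * sigmoid(b), with shape (1, 2)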


class Softmax(Operation):
    def __init__(self, axis=-1):
        super().__init__()
33 changes: 33 additions & 0 deletions keras/src/ops/nn_test.py
@@ -145,6 +145,10 @@ def test_celu(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.celu(x).shape, (None, 2, 3))

    def test_glu(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.glu(x).shape, (None, 2, 3))

    def test_softmax(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -794,6 +798,10 @@ def test_celu(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.celu(x).shape, (1, 2, 3))

    def test_glu(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.glu(x).shape, (1, 2, 3))

    def test_softmax(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1307,6 +1315,13 @@ def test_celu(self):
            [-0.63212055, 0.0, 1.0, 2.0, 3.0],
        )

    def test_glu(self):
        x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32)
        self.assertAllClose(
            knn.glu(x),
            [-0.8807971, 0.0, 0.98201376],
        )

    def test_softmax(self):
        x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
        self.assertAllClose(
@@ -2396,6 +2411,24 @@ def test_celu(self, dtype):
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
    def test_glu(self, dtype):
        import jax.nn as jnn
        import jax.numpy as jnp

        x = knp.ones((2), dtype=dtype)
        x_jax = jnp.ones((2), dtype=dtype)
        expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knn.glu(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knn.Glu().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
    def test_hard_sigmoid(self, dtype):
        import jax.nn as jnn
