Add GLU activation #20392

Merged · 3 commits · Oct 24, 2024
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -11,6 +11,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -71,6 +71,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -17,6 +17,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -11,6 +11,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -71,6 +71,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -17,6 +17,7 @@
from keras.src.ops.nn import dot_product_attention
from keras.src.ops.nn import elu
from keras.src.ops.nn import gelu
from keras.src.ops.nn import glu
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -4,6 +4,7 @@
from keras.src.activations.activations import elu
from keras.src.activations.activations import exponential
from keras.src.activations.activations import gelu
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import leaky_relu
@@ -35,6 +36,7 @@
softsign,
silu,
gelu,
glu,
tanh,
sigmoid,
exponential,
21 changes: 21 additions & 0 deletions keras/src/activations/activations.py
@@ -323,6 +323,27 @@ def celu(x, alpha=1.0):
return ops.celu(x, alpha=alpha)


@keras_export("keras.activations.glu")
def glu(x, axis=-1):
"""Gated Linear Unit (GLU) activation function.

The GLU activation function is defined as:

`glu(x) = a * sigmoid(b)`,

where `x` is split into two equal parts `a` and `b` along the given axis.

Args:
x: Input tensor.
axis: The axis along which to split the input tensor. Defaults to `-1`.

Reference:

- [Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)
"""
return ops.glu(x, axis=axis)


@keras_export("keras.activations.tanh")
def tanh(x):
"""Hyperbolic tangent activation function.
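For context, a minimal usage sketch of the activation added above (values are illustrative and assume a Keras install that includes this change):

import numpy as np
import keras

# The last axis (size 4) is split into a = [-1., 0.] and b = [1., 2.];
# the result is a * sigmoid(b), so the output has shape (1, 2).
x = np.array([[-1.0, 0.0, 1.0, 2.0]], dtype="float32")
y = keras.activations.glu(x, axis=-1)
print(y)  # approximately [[-0.7310586, 0.0]]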
15 changes: 15 additions & 0 deletions keras/src/activations/activations_test.py
@@ -598,6 +598,21 @@ def celu(x, alpha=1.0):
expected = celu(x, True)
self.assertAllClose(result, expected, rtol=1e-05)

def test_glu(self):
def glu(x, axis=-1):
x1, x2 = np.split(x, 2, axis)
return x1 * (1 / (1 + np.exp(-x2)))

x = np.random.random((2, 4))
result = activations.glu(x[np.newaxis, :])[0]
expected = glu(x)
self.assertAllClose(result, expected, rtol=1e-05)

x = np.random.random((2, 4))
result = activations.glu(x[np.newaxis, :], axis=-2)[0]
expected = glu(x, axis=-2)
self.assertAllClose(result, expected, rtol=1e-05)

def test_elu(self):
x = np.random.random((2, 5))
result = activations.elu(x[np.newaxis, :])[0]
5 changes: 5 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -90,6 +90,11 @@ def celu(x, alpha=1.0):
return jnn.celu(x, alpha=alpha)


def glu(x, axis=-1):
x = convert_to_tensor(x)
return jnn.glu(x, axis=axis)


def softmax(x, axis=-1):
x = convert_to_tensor(x)
return jnn.softmax(x, axis=axis)
11 changes: 11 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -121,6 +121,17 @@ def celu(x, alpha=1.0):
)


def glu(x, axis=-1):
x = convert_to_tensor(x)
if x.shape[axis] % 2 != 0:
raise ValueError(
"axis size must be divisible by 2. "
f"Received: x.shape={x.shape} with axis={axis}"
)
x1, x2 = np.split(x, 2, axis)
return x1 * (1 / (1 + np.exp(-x2)))


def softmax(x, axis=None):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
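As a quick sanity check of the split-and-sigmoid formula used by the backend implementations, a standalone NumPy reference (independent of Keras) reproduces the expected values used later in nn_test.py:

import numpy as np

def glu_ref(x, axis=-1):
    # Split the given axis into two equal halves a and b, then gate: a * sigmoid(b).
    a, b = np.split(x, 2, axis=axis)
    return a * (1.0 / (1.0 + np.exp(-b)))

x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(glu_ref(x))  # approximately [-0.8807971, 0.0, 0.98201376]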
10 changes: 10 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -82,6 +82,16 @@ def celu(x, alpha=1.0):
)


def glu(x, axis=-1):
if x.shape[axis] % 2 != 0:
raise ValueError(
"axis size must be divisible by 2. "
f"Received: x.shape={x.shape} with axis={axis}"
)
x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis)
return x1 * tf.sigmoid(x2)


def softmax(x, axis=-1):
logits = x
if axis is None:
5 changes: 5 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -93,6 +93,11 @@ def celu(x, alpha=1.0):
return tnn.celu(x, alpha=alpha)


def glu(x, axis=-1):
x = convert_to_tensor(x)
return tnn.glu(x, dim=axis)


def softmax(x, axis=-1):
x = convert_to_tensor(x)
dtype = backend.standardize_dtype(x.dtype)
41 changes: 41 additions & 0 deletions keras/src/ops/nn.py
@@ -538,6 +538,47 @@ def celu(x, alpha=1.0):
return backend.nn.celu(x, alpha)


class Glu(Operation):
def __init__(self, axis=-1):
super().__init__()
self.axis = axis

def call(self, x):
return backend.nn.glu(x, axis=self.axis)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
def glu(x, axis=-1):
"""Gated Linear Unit (GLU) activation function.

It is defined as:

`f(x) = a * sigmoid(b)`
where `x` is split into `a` and `b` along the given axis.

Args:
x: Input tensor.
axis: The axis along which to split the input tensor. Defaults to `-1`.

Returns:
A tensor with the same shape as the input, except that the size of the
dimension along `axis` is halved.

Example:

>>> x = np.array([-1., 0., 1., 1.])
>>> x_glu = keras.ops.glu(x)
>>> print(x_glu)
array([-0.73105858, 0. ], shape=(2,), dtype=float64)

"""
if any_symbolic_tensors((x,)):
return Glu(axis).symbolic_call(x)
return backend.nn.glu(x, axis=axis)


class Softmax(Operation):
def __init__(self, axis=-1):
super().__init__()
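To illustrate the `axis` argument of the public op, a small sketch (shapes chosen for illustration; as the NumPy and TensorFlow backends above show, an odd-sized split axis raises a ValueError):

import numpy as np
import keras

x = np.arange(6, dtype="float32").reshape(2, 3)
# Splitting axis 0 (size 2) gives a = [[0., 1., 2.]] and b = [[3., 4., 5.]];
# the output keeps the other dimensions and halves the split axis.
y = keras.ops.glu(x, axis=0)
print(y.shape)  # (1, 3)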
33 changes: 33 additions & 0 deletions keras/src/ops/nn_test.py
@@ -145,6 +145,10 @@ def test_celu(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.celu(x).shape, (None, 2, 3))

def test_glu(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.glu(x).shape, (None, 2, 3))

def test_softmax(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -794,6 +798,10 @@ def test_celu(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.celu(x).shape, (1, 2, 3))

def test_glu(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.glu(x).shape, (1, 2, 3))

def test_softmax(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1307,6 +1315,13 @@ def test_celu(self):
[-0.63212055, 0.0, 1.0, 2.0, 3.0],
)

def test_glu(self):
x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32)
self.assertAllClose(
knn.glu(x),
[-0.8807971, 0.0, 0.98201376],
)

def test_softmax(self):
x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllClose(
@@ -2396,6 +2411,24 @@ def test_celu(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
def test_glu(self, dtype):
import jax.nn as jnn
import jax.numpy as jnp

x = knp.ones((2), dtype=dtype)
x_jax = jnp.ones((2), dtype=dtype)
expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)

self.assertEqual(
standardize_dtype(knn.glu(x).dtype),
expected_dtype,
)
self.assertEqual(
standardize_dtype(knn.Glu().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
def test_hard_sigmoid(self, dtype):
import jax.nn as jnn