diff --git a/scenic/projects/performer/performer.py b/scenic/projects/performer/performer.py
index 71ce2f95..91c8522c 100644
--- a/scenic/projects/performer/performer.py
+++ b/scenic/projects/performer/performer.py
@@ -17,15 +17,14 @@
 # pylint: disable=invalid-name
 
 import abc
+import functools
 import math
 from typing import (Any, Dict, Optional, Sequence, Tuple, Union)
 
 from flax.linen.linear import PrecisionLike
-
 import jax
 from jax import random
 import jax.numpy as jnp
-
 from scenic.projects.performer import subquadratic_attention as sat
 from scenic.projects.performer import utils as ut
 
@@ -43,6 +42,14 @@
 Array = Any
 
 
+def linear_gaussian(x):
+  x_norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
+  x_sq_norm = x_norm**2
+  return jnp.exp(-0.5 * x_sq_norm) * jax.nn.relu(
+      x
+  )  # instead use RF based transformation
+
+
 class RandomMatrix(abc.ABC):
   r"""Abstract class providing a method for constructing 2D random arrays.
 
@@ -398,13 +405,13 @@ def expplus_softmax_kernel_transformation(
   return data_dash
 
 
-#------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------
 # Performers-compatible Relative Positional Encoding mechanism.
 #
 # The implementation is taken from the following paper: "Relative Positional
 # Encoding for Transformers with Linear Complexity"
 # (github code: https://cifkao.github.io/spe/)
-#------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------
 
 
 def sinespe(rng_key,
@@ -758,8 +765,12 @@ def favor_attention(query, key, value,
                                     inputs_mask)
     query = query[:, hybrid_global_size:, :, :]
   if not data_dependent_kfs:
-    query_prime = kernel_transformation(query, True, projection_matrix)
-    key_prime = kernel_transformation(key, False, projection_matrix)
+    query_prime = kernel_transformation(
+        data=query, is_query=True, projection_matrix=projection_matrix
+    )
+    key_prime = kernel_transformation(
+        data=key, is_query=False, projection_matrix=projection_matrix
+    )
   else:
     query_prime = kernel_transformation(query, key, True, projection_matrix)
     key_prime = kernel_transformation(key, query, False, projection_matrix)
@@ -849,6 +860,8 @@ def masked_favor_attention(query, key, value, masker, mask,
                            kernel_config):
   if kernel_config['kernel_transformation'] == 'softmax':
     kernel_transformation = exp_softmax_kernel_transformation
+  elif kernel_config['kernel_transformation'] == 'linear_gaussian':
+    kernel_transformation = sat.softmax_positive_rfs
   else:
     if kernel_config['kernel_transformation'] == 'relu':
       activation_fn = jax.nn.relu
@@ -1198,16 +1211,17 @@ def regular_performer_dot_product_attention(
 
   if kernel_config['kernel_transformation'] == 'softmax':
     kernel_transformation = exp_softmax_kernel_transformation
+  elif kernel_config['kernel_transformation'] == 'linear_gaussian':
+    kernel_transformation = sat.softmax_positive_rfs
   else:
     if kernel_config['kernel_transformation'] == 'relu':
       activation_fn = jax.nn.relu
     else:
       activation_fn = (lambda x: x * x * x * x)
 
-    def gen_transformation(a, b, c):
-      return generic_kernel_transformation(a, b, c, activation_fn=activation_fn)
-
-    kernel_transformation = gen_transformation
+    kernel_transformation = functools.partial(
+        generic_kernel_transformation, activation_fn=activation_fn
+    )
   return favor_attention(
       query,
       key,
@@ -1337,6 +1351,8 @@ def sharp_masked_performer_dot_product_attention(
   del precision
   if kernel_config['kernel_transformation'] == 'softmax':
     kernel_transformation = exp_softmax_kernel_transformation
+  elif kernel_config['kernel_transformation'] == 'linear_gaussian':
+    kernel_transformation = sat.softmax_positive_rfs
   else:
     if kernel_config['kernel_transformation'] == 'relu':
       activation_fn = jax.nn.relu
@@ -1423,4 +1439,3 @@ def pseudolocal_subquadratic_attention(
   )
   result = numerator / denominator
   return result
-
diff --git a/scenic/projects/performer/subquadratic_attention.py b/scenic/projects/performer/subquadratic_attention.py
index ffa34fc3..2323fb84 100644
--- a/scenic/projects/performer/subquadratic_attention.py
+++ b/scenic/projects/performer/subquadratic_attention.py
@@ -62,7 +62,7 @@ def softmax_positive_rfs(
     projection_matrix: jax.Array | None = None,
     numerical_stabilizer: float = 0.000001,
     is_query: bool = True,
-
+    temp: float = 5.0,
 ) -> jax.Array:
   r"""Computes positive random features from https://arxiv.org/abs/2009.14794.
 
@@ -73,17 +73,20 @@
       stands for the number of projections.
     numerical_stabilizer: small positive constant used for numerical stability.
     is_query: determines whether input data tensor is a query- or key-tensor.
+    temp: temperature parameter of the softmax kernel.
 
   Returns:
     Corresponding kernel feature map used to linearize softmax kernel.
   """
-  h = lambda X: jnp.exp(-0.5 * jnp.sum(jnp.square(X), axis=-1, keepdims=True))
+  h = lambda X: jnp.exp(
+      -0.5 * temp * jnp.sum(jnp.square(X), axis=-1, keepdims=True)
+  )
   if is_query:
     axis = (-1,)
   else:
     axis = None
   activation_fn = lambda P, X: h(X) * jnp.exp(
-      P - jnp.max(P, axis=axis, keepdims=True)
+      temp * (P - jnp.max(P, axis=axis, keepdims=True))
   )
   return general_kernel_linearization(
       data, projection_matrix, numerical_stabilizer, activation_fn
@@ -128,4 +131,3 @@ def softmax_hyper_positive_rfs(
       data, projection_matrix, numerical_stabilizer, negative_activation_fn
   )
   return jnp.concatenate((positive_exponential, negative_exponential), axis=-1)
-
diff --git a/scenic/projects/pointcloud/models.py b/scenic/projects/pointcloud/models.py
index 5f4835ff..ba169434 100644
--- a/scenic/projects/pointcloud/models.py
+++ b/scenic/projects/pointcloud/models.py
@@ -95,21 +95,22 @@ def __call__(self,
       output: Tensor of shape [batch_size, num_points, feature_dim]
     """
     key_channels = self.key_channels or self.out_channels
-    input_q = nn.Conv(
-        key_channels,
-        kernel_size=(self.kernel_size,),
-        use_bias=True)(
-            inputs)
-    input_k = nn.Conv(
-        key_channels,
-        kernel_size=(self.kernel_size,),
-        use_bias=True)(
-            inputs)
-    input_v = nn.Conv(
-        self.out_channels,
-        kernel_size=(self.kernel_size,),
-        use_bias=True)(
-            inputs)
+    if (self.attention_fn_configs is not None) and self.attention_fn_configs[
+        'neighbor_attn'
+    ]:
+      input_q = coords
+      input_k = coords
+      input_v = inputs
+    else:
+      input_q = nn.Conv(
+          key_channels, kernel_size=(self.kernel_size,), use_bias=True
+      )(inputs)
+      input_k = nn.Conv(
+          key_channels, kernel_size=(self.kernel_size,), use_bias=True
+      )(inputs)
+      input_v = nn.Conv(
+          self.out_channels, kernel_size=(self.kernel_size,), use_bias=True
+      )(inputs)
 
     if (
         self.attention_fn_configs is None
@@ -303,6 +304,7 @@ class PointCloudTransformerClassifier(nn.Module):
 
   @nn.compact
   def __call__(self, inputs, train: bool = False, debug: bool = False):
+    output = inputs
    if self.attention_type == 'standard':
       output = PointCloudTransformerEncoder(
           in_dim=self.in_dim,
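Note on the kernel changes above: the patch adds a temperature parameter to the positive random-feature map in softmax_positive_rfs and exposes it to the Performer attention variants through the new 'linear_gaussian' option, which dispatches to sat.softmax_positive_rfs. The snippet below is a minimal, self-contained sketch of that temperature-scaled feature map, included for reference only; the tensor shapes, the einsum contraction, and the omission of the numerical stabilizer are assumptions made for illustration and are not part of the patch.

# Standalone sketch (illustration only) of the temperature-scaled positive
# random features introduced in softmax_positive_rfs above:
#   phi(x) = exp(-0.5 * temp * ||x||^2) * exp(temp * (x @ W.T - max)).
import jax
import jax.numpy as jnp


def positive_rfs_sketch(data, projection_matrix, is_query=True, temp=5.0):
  # Assumed shapes: data [..., d], projection_matrix [m, d] (m random features).
  projections = jnp.einsum('...d,md->...m', data, projection_matrix)
  h = jnp.exp(-0.5 * temp * jnp.sum(jnp.square(data), axis=-1, keepdims=True))
  # As in the patch: queries are max-shifted over the feature axis,
  # keys over the whole tensor.
  axis = (-1,) if is_query else None
  return h * jnp.exp(
      temp * (projections - jnp.max(projections, axis=axis, keepdims=True))
  )


rng = jax.random.PRNGKey(0)
x_rng, w_rng = jax.random.split(rng)
x = jax.random.normal(x_rng, (2, 16, 4, 64))   # [batch, length, heads, head_dim]
w = jax.random.normal(w_rng, (128, 64))        # [num_features, head_dim]
q_prime = positive_rfs_sketch(x, w, is_query=True)   # [2, 16, 4, 128]
k_prime = positive_rfs_sketch(x, w, is_query=False)  # [2, 16, 4, 128]

In the attention entry points patched above, this map is selected by setting kernel_config['kernel_transformation'] = 'linear_gaussian'.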