Coordinate attention for neighbor encoding in HD-VPD
PiperOrigin-RevId: 640478071
Scenic Authors committed Aug 27, 2024
1 parent 6731c59 commit 852f250
Showing 3 changed files with 49 additions and 30 deletions.
37 changes: 26 additions & 11 deletions scenic/projects/performer/performer.py
@@ -17,15 +17,14 @@
# pylint: disable=invalid-name

import abc
import functools
import math
from typing import (Any, Dict, Optional, Sequence, Tuple, Union)

from flax.linen.linear import PrecisionLike

import jax
from jax import random
import jax.numpy as jnp

from scenic.projects.performer import subquadratic_attention as sat
from scenic.projects.performer import utils as ut

@@ -43,6 +42,14 @@
Array = Any


def linear_gaussian(x):
  """ReLU feature map rescaled by the Gaussian factor exp(-0.5 * ||x||^2)."""
  x_norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
  x_sq_norm = x_norm**2
  return jnp.exp(-0.5 * x_sq_norm) * jax.nn.relu(
      x
  )  # TODO: use an RF-based transformation instead.


class RandomMatrix(abc.ABC):
r"""Abstract class providing a method for constructing 2D random arrays.
@@ -398,13 +405,13 @@ def expplus_softmax_kernel_transformation(
return data_dash


#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Performers-compatible Relative Positional Encoding mechanism.
#
# The implementation is taken from the following paper: "Relative Positional
# Encoding for Transformers with Linear Complexity"
# (github code: https://cifkao.github.io/spe/)
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------


def sinespe(rng_key,
@@ -758,8 +765,12 @@ def favor_attention(query,
key, value, inputs_mask)
query = query[:, hybrid_global_size:, :, :]
if not data_dependent_kfs:
query_prime = kernel_transformation(query, True, projection_matrix)
key_prime = kernel_transformation(key, False, projection_matrix)
query_prime = kernel_transformation(
data=query, is_query=True, projection_matrix=projection_matrix
)
key_prime = kernel_transformation(
data=key, is_query=False, projection_matrix=projection_matrix
)
else:
query_prime = kernel_transformation(query, key, True, projection_matrix)
key_prime = kernel_transformation(key, query, False, projection_matrix)
@@ -849,6 +860,8 @@ def masked_favor_attention(query, key, value, masker, mask, kernel_config):

if kernel_config['kernel_transformation'] == 'softmax':
kernel_transformation = exp_softmax_kernel_transformation
elif kernel_config['kernel_transformation'] == 'linear_gaussian':
kernel_transformation = sat.softmax_positive_rfs
else:
if kernel_config['kernel_transformation'] == 'relu':
activation_fn = jax.nn.relu
@@ -1198,16 +1211,17 @@ def regular_performer_dot_product_attention(

if kernel_config['kernel_transformation'] == 'softmax':
kernel_transformation = exp_softmax_kernel_transformation
elif kernel_config['kernel_transformation'] == 'linear_gaussian':
kernel_transformation = sat.softmax_positive_rfs
else:
if kernel_config['kernel_transformation'] == 'relu':
activation_fn = jax.nn.relu
else:
activation_fn = (lambda x: x * x * x * x)

def gen_transformation(a, b, c):
return generic_kernel_transformation(a, b, c, activation_fn=activation_fn)

kernel_transformation = gen_transformation
kernel_transformation = functools.partial(
generic_kernel_transformation, activation_fn=activation_fn
)
return favor_attention(
query,
key,
@@ -1337,6 +1351,8 @@ def sharp_masked_performer_dot_product_attention(
del precision
if kernel_config['kernel_transformation'] == 'softmax':
kernel_transformation = exp_softmax_kernel_transformation
elif kernel_config['kernel_transformation'] == 'linear_gaussian':
kernel_transformation = sat.softmax_positive_rfs
else:
if kernel_config['kernel_transformation'] == 'relu':
activation_fn = jax.nn.relu
@@ -1423,4 +1439,3 @@ def pseudolocal_subquadratic_attention(
)
result = numerator / denominator
return result
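For context, the following is a minimal, self-contained sketch of how a kernel feature map selected via kernel_config['kernel_transformation'] (including the new 'linear_gaussian' option, which dispatches to sat.softmax_positive_rfs) is consumed by FAVOR-style linear attention. It is an illustration only, not the Scenic implementation; the per-head shapes, the toy relu_feature_map, and the helper names are assumptions.

# Illustrative FAVOR-style linear attention (assumed per-head shapes: [L, D]).
# The feature map below is a stand-in for the kernel transformations dispatched
# through kernel_config in performer.py.
import functools
import jax
import jax.numpy as jnp


def relu_feature_map(
    data, is_query, projection_matrix, activation_fn=jax.nn.relu
):
  """Generic kernel transformation: random projection + non-negative activation."""
  del is_query  # Unused by this simple variant.
  return activation_fn(data @ projection_matrix.T) + 1e-6


def linear_attention(query, key, value, kernel_transformation, projection_matrix):
  q_prime = kernel_transformation(query, True, projection_matrix)   # [L, M]
  k_prime = kernel_transformation(key, False, projection_matrix)    # [L, M]
  kv = jnp.einsum('lm,ld->md', k_prime, value)                      # [M, Dv]
  numerator = jnp.einsum('lm,md->ld', q_prime, kv)                  # [L, Dv]
  denominator = q_prime @ jnp.sum(k_prime, axis=0)                  # [L]
  return numerator / denominator[:, None]


kernel_transformation = functools.partial(
    relu_feature_map, activation_fn=jax.nn.relu)
q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (16, 8)) for i in range(3))
W = jax.random.normal(jax.random.PRNGKey(3), (32, 8))
print(linear_attention(q, k, v, kernel_transformation, W).shape)  # (16, 8)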

10 changes: 6 additions & 4 deletions scenic/projects/performer/subquadratic_attention.py
@@ -62,7 +62,7 @@ def softmax_positive_rfs(
projection_matrix: jax.Array | None = None,
numerical_stabilizer: float = 0.000001,
is_query: bool = True,

temp: float = 5.0,
) -> jax.Array:
r"""Computes positive random features from https://arxiv.org/abs/2009.14794.
@@ -73,17 +73,20 @@
stands for the number of projections.
numerical_stabilizer: small positive constant used for numerical stability.
is_query: determines whether input data tensor is a query- or key-tensor.
temp: temperature parameter of the softmax kernel.
Returns:
Corresponding kernel feature map used to linearize softmax kernel.
"""
h = lambda X: jnp.exp(-0.5 * jnp.sum(jnp.square(X), axis=-1, keepdims=True))
h = lambda X: jnp.exp(
-0.5 * temp * jnp.sum(jnp.square(X), axis=-1, keepdims=True)
)
if is_query:
axis = (-1,)
else:
axis = None
activation_fn = lambda P, X: h(X) * jnp.exp(
P - jnp.max(P, axis=axis, keepdims=True)
temp * (P - jnp.max(P, axis=axis, keepdims=True))
)
return general_kernel_linearization(
data, projection_matrix, numerical_stabilizer, activation_fn
@@ -128,4 +131,3 @@ def softmax_hyper_positive_rfs(
data, projection_matrix, numerical_stabilizer, negative_activation_fn
)
return jnp.concatenate((positive_exponential, negative_exponential), axis=-1)
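For intuition, here is a standalone sketch of the temperature-scaled positive random-feature map defined above: project the inputs, rescale by exp(-0.5 * temp * ||x||^2), and exponentiate with the temperature, subtracting the per-row (query) or global (key) max purely for numerical stability. The shapes ([L, D] inputs, [M, D] Gaussian projections) and the function name are assumptions for illustration; the library entry point is softmax_positive_rfs.

# Sketch of a temperature-scaled positive random-feature map mirroring the
# h(X) * exp(temp * (P - max P)) structure above.
import jax
import jax.numpy as jnp


def positive_rfs(data, projection_matrix, is_query=True, temp=5.0, eps=1e-6):
  proj = data @ projection_matrix.T                      # [L, M]
  h = jnp.exp(-0.5 * temp * jnp.sum(jnp.square(data), axis=-1, keepdims=True))
  # Per-row stabilization for queries, global stabilization for keys; the
  # subtracted max rescales numerator and denominator identically, so the
  # normalized attention weights are unchanged.
  axis = (-1,) if is_query else None
  return h * jnp.exp(temp * (proj - jnp.max(proj, axis=axis, keepdims=True))) + eps


x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
W = jax.random.normal(jax.random.PRNGKey(1), (64, 8))   # rows ~ N(0, I)
print(positive_rfs(x, W, is_query=True).shape)   # (16, 64)
print(positive_rfs(x, W, is_query=False).shape)  # (16, 64)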

32 changes: 17 additions & 15 deletions scenic/projects/pointcloud/models.py
@@ -95,21 +95,22 @@ def __call__(self,
output: Tensor of shape [batch_size, num_points, feature_dim]
"""
key_channels = self.key_channels or self.out_channels
input_q = nn.Conv(
key_channels,
kernel_size=(self.kernel_size,),
use_bias=True)(
inputs)
input_k = nn.Conv(
key_channels,
kernel_size=(self.kernel_size,),
use_bias=True)(
inputs)
input_v = nn.Conv(
self.out_channels,
kernel_size=(self.kernel_size,),
use_bias=True)(
inputs)
if (self.attention_fn_configs is not None) and self.attention_fn_configs[
'neighbor_attn'
]:
input_q = coords
input_k = coords
input_v = inputs
else:
input_q = nn.Conv(
key_channels, kernel_size=(self.kernel_size,), use_bias=True
)(inputs)
input_k = nn.Conv(
key_channels, kernel_size=(self.kernel_size,), use_bias=True
)(inputs)
input_v = nn.Conv(
self.out_channels, kernel_size=(self.kernel_size,), use_bias=True
)(inputs)
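
For intuition only, below is a quadratic-cost reference sketch of the coordinate-attention idea behind the neighbor_attn branch: queries and keys are the raw point coordinates and values are the point features, so the attention weights are a function of the spatial layout alone rather than of learned feature projections. The shapes ([B, N, 3] coordinates, [B, N, C] features) and the temperature are assumptions; the Scenic code path instead feeds coords/inputs into the linearized attention kernels configured above.

# Reference O(N^2) coordinate attention over a point cloud (illustration only).
import jax
import jax.numpy as jnp


def coordinate_attention(coords, feats, temp=5.0):
  # coords: [B, N, 3], feats: [B, N, C].
  logits = temp * jnp.einsum('bnd,bmd->bnm', coords, coords)  # [B, N, N]
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('bnm,bmc->bnc', weights, feats)           # [B, N, C]


coords = jax.random.normal(jax.random.PRNGKey(0), (2, 128, 3))
feats = jax.random.normal(jax.random.PRNGKey(1), (2, 128, 16))
print(coordinate_attention(coords, feats).shape)  # (2, 128, 16)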

if (
self.attention_fn_configs is None
@@ -303,6 +304,7 @@ class PointCloudTransformerClassifier(nn.Module):

@nn.compact
def __call__(self, inputs, train: bool = False, debug: bool = False):
output = inputs
if self.attention_type == 'standard':
output = PointCloudTransformerEncoder(
in_dim=self.in_dim,
