This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Update detr attention #47

Open · wants to merge 3 commits into base: main
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -6,7 +6,7 @@ on:

jobs:
build:
name: Build and Test Colossal-AI
name: Build and Test Titans
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
@@ -23,7 +23,7 @@ jobs:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install Colossal-AI
run: |
pip install colossalai==0.1.4+torch1.10cu11.3 -f https://release.colossalai.org
pip install colossalai==0.1.7+torch1.10cu11.3 -f https://release.colossalai.org
pip install -v .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
2 changes: 1 addition & 1 deletion titans/layer/attention/__init__.py
@@ -1,5 +1,5 @@
from .gpt_attention import GPTSelfAttention
from .detr_attention import DeTrCrossAttention
from .detr_attention import DeTrAttention
from .vit_attention import ViTSelfAttention
from .vit_moe_attention import SelfAttentionForMoe
from .transformer_attention import TransformerSelfAttention, TransformerMultiHeadAttention
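
With the export renamed, downstream code imports the class through the package, as the detr_block.py change below does; for example:

```python
# new public name exported by titans/layer/attention/__init__.py
from titans.layer.attention import DeTrAttention
```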
49 changes: 30 additions & 19 deletions titans/layer/attention/detr_attention.py
@@ -6,10 +6,10 @@
from colossalai import nn as col_nn
from ..init_rules import init_rules
from titans.decorator import no_support
# This part needs to work together with col_nn.Linear (row, col) in order to parallelize better.


@no_support(['sp'])
class DeTrCrossAttention(nn.Module):
class DeTrAttention(nn.Module):

def __init__(self,
hidden_size: int,
@@ -25,46 +25,57 @@ def __init__(self,
hidden_size,
dtype=dtype,
bias=bias,
)
self.key_value = col_nn.Linear1D_Col(hidden_size,
2 * hidden_size,
**init_rules[init_method]['transformer'])
self.key = col_nn.Linear1D_Col(hidden_size,
hidden_size,
dtype=dtype,
bias=bias,
**init_rules[init_method]['transformer'])
self.value = col_nn.Linear1D_Col(hidden_size,
hidden_size,
dtype=dtype,
bias=bias,
)
**init_rules[init_method]['transformer'])
self.attention_dropout = col_nn.Dropout(attention_dropout)
self.dense = col_nn.Linear1D_Row(hidden_size, hidden_size, dtype=dtype, bias=True)
self.dropout = col_nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x, memory):
q = self.query(x)
kv = self.key_value(memory)
all_head_size = kv.shape[-1] // 2
def forward(self, q, k, v, attn_mask=None, key_padding_mask=None):
q = self.query(q)
k = self.key(k)
v = self.value(v)

all_head_size = q.shape[-1]
num_attention_heads = all_head_size // self.attention_head_size

new_q_shape = q.shape[:-1] + (num_attention_heads, self.attention_head_size)
q = q.view(new_q_shape)
q = q.permute((0, 2, 1, 3))
q = q.permute((2, 3, 0, 1)) # ?

new_kv_shape = kv.shape[:-1] + (num_attention_heads, 2 * self.attention_head_size)
kv = kv.view(new_kv_shape)
kv = kv.permute((0, 2, 1, 3))
k, v = torch.chunk(kv, 2, dim=-1)
k = k.permute((2, 3, 0, 1)) # ?
v = v.permute((2, 3, 0, 1)) # ?
new_k_shape = k.shape[:-1] + (num_attention_heads, self.attention_head_size)
k = k.view(new_k_shape)
k = k.permute((0, 2, 1, 3))

new_v_shape = v.shape[:-1] + (num_attention_heads, self.attention_head_size)
v = v.view(new_v_shape)
v = v.permute((0, 2, 1, 3))

x = torch.matmul(q, k.transpose(-1, -2))
x = x / math.sqrt(self.attention_head_size)

# if attn_mask is not None:
# x += attn_mask

x = self.softmax(x)
x = self.attention_dropout(x)

x = torch.matmul(x, v)
x = x.transpose(1, 2)
new_context_layer_shape = x.size()[:-2] + (all_head_size,)
# the size of x after reshape is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
x = x.reshape(new_context_layer_shape)
x = x.transpose(0, 1)

# the size of x after dense is (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
x = self.dense(x)
x = self.dropout(x)

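For orientation, here is a minimal sketch of the multi-head scaled dot-product attention that the rewritten forward computes (separate query/key/value projections, softmax over scaled scores, output projection). It uses plain torch.nn.Linear in place of col_nn.Linear1D_Col / Linear1D_Row, a batch-first [b, s, h] layout, and an additive attn_mask; the class name PlainMultiHeadAttention and all sizes are illustrative, and the sequence-first permutes marked "# ?" in the diff are deliberately not reproduced here.

```python
# Minimal, framework-free sketch of the attention math behind DeTrAttention,
# assuming batch-first [b, s, h] inputs and plain nn.Linear projections.
import math
import torch
from torch import nn


class PlainMultiHeadAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.head_size = hidden_size // num_heads
        self.num_heads = num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, attn_mask=None):
        b = q.shape[0]
        # project and split into heads: [b, s, h] -> [b, heads, s, head_size]
        q = self.query(q).view(b, -1, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        k = self.key(k).view(b, -1, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        v = self.value(v).view(b, -1, self.num_heads, self.head_size).permute(0, 2, 1, 3)

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)
        if attn_mask is not None:
            scores = scores + attn_mask      # additive mask, as the commented-out lines suggest
        weights = torch.softmax(scores, dim=-1)

        out = torch.matmul(weights, v)       # [b, heads, s_q, head_size]
        out = out.transpose(1, 2).reshape(b, -1, self.num_heads * self.head_size)
        return self.dropout(self.dense(out))  # [b, s_q, h]


# shape check: 4 query tokens cross-attending over 16 memory tokens
attn = PlainMultiHeadAttention(hidden_size=256, num_heads=8)
x = torch.randn(2, 4, 256)        # queries  [b, s_q, h]
memory = torch.randn(2, 16, 256)  # memory   [b, s_kv, h]
print(attn(x, memory, memory).shape)  # torch.Size([2, 4, 256])
```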
27 changes: 16 additions & 11 deletions titans/layer/block/detr_block.py
@@ -6,7 +6,7 @@
from colossalai.nn.layer.utils import CheckpointModule
from torch import dtype, nn

from titans.layer.attention import ViTSelfAttention, DeTrCrossAttention
from titans.layer.attention import DeTrAttention
from titans.layer.mlp import ViTMLP
from titans.decorator import support_tp_pp_only

@@ -29,7 +29,7 @@ def __init__(self,
init_method: str = 'torch'):
super().__init__(checkpoint)
self.norm1 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
self.attn = ViTSelfAttention(hidden_size=hidden_size,
self.attn = DeTrAttention(hidden_size=hidden_size,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
@@ -46,10 +46,12 @@ def __init__(self,
bias=bias,
init_method=init_method)

def _forward(self, x):
x = x + self.drop_path(self.norm1(self.attn(x)))
def _forward(self, x, attn_mask=None, key_padding_mask=None):
# input dimension: [b, s, h]
x = x.transpose(0, 1)
x = x + self.drop_path(self.norm1(self.attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
return x.transpose(0, 1)


@support_tp_pp_only()
@@ -73,15 +75,15 @@ def __init__(self,
self.norm2 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)
self.norm3 = col_nn.LayerNorm(normalized_shape=hidden_size, eps=layernorm_epsilon, dtype=dtype)

self.attn1 = ViTSelfAttention(hidden_size=hidden_size,
self.attn1 = DeTrAttention(hidden_size=hidden_size,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
bias=bias,
dtype=dtype,
init_method=init_method)

self.attn2 = DeTrCrossAttention(hidden_size=hidden_size,
self.attn2 = DeTrAttention(hidden_size=hidden_size,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
@@ -99,8 +101,11 @@ def __init__(self,
bias=bias,
init_method=init_method)

def _forward(self, x, memory):
x = x + self.drop_path(self.norm1(self.attn1(x)))
x = x + self.drop_path(self.norm2(self.attn2(x, memory)))
def _forward(self, x, memory, self_attn_mask=None, self_attn_key_padding_mask=None, multihead_attn_mask=None, multihead_attn_key_padding_mask=None):
# input dimensions: x [b, q, h], memory [b, s, h]
x = x.transpose(0, 1)
memory = memory.transpose(0, 1)
x = x + self.drop_path(self.norm1(self.attn1(x, x, x, attn_mask=self_attn_mask, key_padding_mask=self_attn_key_padding_mask)))
x = x + self.drop_path(self.norm2(self.attn2(x, memory, memory, attn_mask=multihead_attn_mask, key_padding_mask=multihead_attn_key_padding_mask)))
x = x + self.drop_path(self.mlp(self.norm3(x)))
return x
return x.transpose(0, 1)
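
And a hedged sketch of the call pattern the updated decoder-block _forward follows: batch-first [b, s, h] inputs, transpose to sequence-first, self-attention on x, cross-attention against memory, MLP, transpose back. torch.nn.MultiheadAttention and plain nn.LayerNorm / nn.Linear stand in for DeTrAttention and the col_nn layers, drop_path is omitted, and every name and size below is illustrative rather than taken from the repository.

```python
# Sketch of the decoder-block residual pattern with batch-first inputs.
import torch
from torch import nn

hidden_size, num_heads = 256, 8
self_attn = nn.MultiheadAttention(hidden_size, num_heads)   # expects [s, b, h]
cross_attn = nn.MultiheadAttention(hidden_size, num_heads)
norm1 = nn.LayerNorm(hidden_size)
norm2 = nn.LayerNorm(hidden_size)
norm3 = nn.LayerNorm(hidden_size)
mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
                    nn.Linear(4 * hidden_size, hidden_size))


def decoder_block(x, memory):
    # x: [b, s_q, h], memory: [b, s_kv, h]
    x = x.transpose(0, 1)             # -> [s_q, b, h]
    memory = memory.transpose(0, 1)   # -> [s_kv, b, h]
    x = x + norm1(self_attn(x, x, x, need_weights=False)[0])        # self-attention
    x = x + norm2(cross_attn(x, memory, memory, need_weights=False)[0])  # cross-attention
    x = x + mlp(norm3(x))
    return x.transpose(0, 1)          # -> [b, s_q, h]


out = decoder_block(torch.randn(2, 100, hidden_size), torch.randn(2, 625, hidden_size))
print(out.shape)  # torch.Size([2, 100, 256])
```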