Commit

make _attn wrapper much faster
TomFrederik committed Mar 16, 2022
1 parent 4df9596 commit 905594e
Showing 1 changed file with 19 additions and 10 deletions.
unseal/hooks/common_hooks.py: 29 changes (19 additions, 10 deletions)
@@ -210,7 +210,7 @@ def gpt_attn_wrapper(
     c_proj: torch.Tensor,
     vocab_embedding: torch.Tensor,
     target_ids: torch.Tensor,
-    batch_size: int = 16,
+    batch_size: Optional[int] = None,
 ) -> Tuple[Callable, Callable]:
     """Wraps around the [AttentionBlock]._attn function to save the individual heads' logits.
     This is necessary because the individual heads' logits are not available on a module level and thus not accessible via a hook.
@@ -225,17 +225,21 @@ def gpt_attn_wrapper(
     :type vocab_matrix: torch.Tensor
     :param target_ids: indices of the target tokens for which the logits are computed
     :type target_ids: torch.Tensor
-    :param batch_size: batch size to reduce compute cost
-    :type batch_size: int
+    :param batch_size: batch size to reduce memory footprint, defaults to None
+    :type batch_size: Optional[int]
     :return: inner, func, the wrapped function and the original function
     :rtype: Tuple[Callable, Callable]
     """
     # TODO Find a smarter/more efficient way of implementing this function
+    # TODO clean up this function
     def inner(query, key, value, *args, **kwargs):
         nonlocal c_proj
         nonlocal target_ids
         nonlocal vocab_embedding
+        nonlocal batch_size
         attn_output, attn_weights = func(query, key, value, *args, **kwargs)
+        if batch_size is None:
+            batch_size = attn_output.shape[0]
         with torch.no_grad():
             temp = attn_weights[...,None] * value[:,:,None]
             if len(c_proj.shape) == 2:
@@ -244,21 +248,26 @@ def inner(query, key, value, *args, **kwargs):
                 temp = temp[0,:,:-1] # could this be done earlier?
             new_temp = []
             for head in tqdm(range(temp.shape[0])):
+                new_temp.append([])
                 for i in range(math.ceil(temp.shape[1] / batch_size)):
                     out = temp[head, i*batch_size:(i+1)*batch_size] @ c_proj[head]
                     out = out @ vocab_embedding # compute logits
-                    out -= out.mean(dim=-1, keepdim=True) # center logits
-                    # select targets
-                    out = out[...,torch.arange(len(target_ids)), target_ids].to('cpu')
-                    new_temp.append(out)
-            new_temp = torch.cat(new_temp, dim=0)
+                    new_temp[-1].append(out)
+
+                # center logits
+                new_temp[-1] = torch.cat(new_temp[-1])
+                new_temp[-1] -= torch.mean(new_temp[-1], dim=-1, keepdim=True)
+                # select targets
+                new_temp[-1] = new_temp[-1][...,torch.arange(len(target_ids)), target_ids]#.to('cpu')
+
+            new_temp = torch.cat(new_temp, dim=0)
             new_temp = einops.rearrange(new_temp, '(h t1) t2 -> h t1 t2', h=temp.shape[0], t1=len(target_ids), t2=len(target_ids))
             max_pos_value = torch.amax(new_temp).item()
             max_neg_value = torch.amax(-new_temp).item()
 
             save_ctx['logits'] = {
-                'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach(),
-                'neg': (new_temp/max_neg_value).clamp(min=-1, max=0).detach(),
+                'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach().cpu(),
+                'neg': (-new_temp/max_neg_value).clamp(min=0, max=1).detach().cpu(),
             }
         return attn_output, attn_weights
     return inner, func
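
For orientation, the hunk above computes a per-head logit attribution: each head's attention-weighted value vectors are pushed through that head's slice of the output projection, unembedded into vocabulary logits, centered, and then read off at the target token ids. The sketch below is a minimal standalone version of that computation; the tensor shapes, the single-example assumption, and the helper name are illustrative rather than taken from the repository.

import math
from typing import Optional

import torch


def per_head_target_logits(
    attn_weights: torch.Tensor,     # (heads, n_pos, n_pos) attention pattern
    value: torch.Tensor,            # (heads, n_pos, head_dim) value vectors
    c_proj: torch.Tensor,           # (heads, head_dim, d_model) per-head output projection
    vocab_embedding: torch.Tensor,  # (d_model, vocab) unembedding matrix
    target_ids: torch.Tensor,       # (n_pos,) token ids to score
    batch_size: Optional[int] = None,
) -> torch.Tensor:
    if batch_size is None:
        batch_size = attn_weights.shape[1]              # one chunk covers every position
    # attention-weighted value vectors: (heads, n_pos, n_pos, head_dim)
    weighted = attn_weights[..., None] * value[:, None]
    per_head = []
    for head in range(weighted.shape[0]):
        chunks = []
        for i in range(math.ceil(weighted.shape[1] / batch_size)):
            chunk = weighted[head, i * batch_size:(i + 1) * batch_size]
            chunk = chunk @ c_proj[head]                # map into the residual stream
            chunks.append(chunk @ vocab_embedding)      # unembed to vocabulary logits
        logits = torch.cat(chunks)                      # (n_pos, n_pos, vocab)
        logits = logits - logits.mean(dim=-1, keepdim=True)   # center once per head
        # read off each position's target token: (n_pos, n_pos)
        per_head.append(logits[..., torch.arange(len(target_ids)), target_ids])
    return torch.stack(per_head)                        # (heads, n_pos, n_pos)

Relative to the previous revision, the chunks are now concatenated per head before centering and target selection, and nothing is shipped to the CPU until the final pos/neg maps are built at the end, which is where the speed-up comes from.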
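
The hunks only show the tail of gpt_attn_wrapper's signature, so the usage sketch below guesses at the leading arguments (the original _attn callable and a dict that receives the saved logits) and at how the per-head projection and unembedding matrices might be pulled out of a Hugging Face GPT-2; every name not visible in the diff should be read as a hypothetical.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from unseal.hooks.common_hooks import gpt_attn_wrapper

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

attn = model.transformer.h[0].attn                       # attention block to instrument
# split the output projection into one (head_dim, d_model) matrix per head
c_proj = attn.c_proj.weight.view(attn.num_heads, attn.head_dim, -1)
vocab_embedding = model.lm_head.weight.T                 # (d_model, vocab) unembedding

input_ids = tokenizer("Interpretability is fun", return_tensors="pt").input_ids
target_ids = input_ids[0, 1:]                            # next-token targets, one per position

save_ctx = {}                                            # assumed: the wrapper writes into this dict
inner, original = gpt_attn_wrapper(                      # assumed order for the leading arguments
    attn._attn, save_ctx, c_proj, vocab_embedding, target_ids
)
attn._attn = inner                                       # patch the block's _attn
with torch.no_grad():
    model(input_ids)                                     # forward pass fills save_ctx['logits']
attn._attn = original                                    # restore the original method

With batch_size left at its new default of None, inner derives a chunk size from the incoming activations instead of the old hard-coded 16; passing a small integer trades extra loop iterations for a lower peak memory footprint, which is what the updated docstring means by "reduce memory footprint".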
