Merge pull request #5 from TomFrederik/develop/0.1.6
Develop/0.1.6
TomFrederik authored Mar 16, 2022
2 parents 5183827 + a16da0f commit 20a8070
Showing 6 changed files with 50 additions and 45 deletions.
2 changes: 1 addition & 1 deletion unseal/__init__.py
@@ -1,3 +1,3 @@
 from . import *
 
-__version__ = '0.1.5'
+__version__ = '0.1.6'
1 change: 0 additions & 1 deletion unseal/hooks/__init__.py
@@ -1,4 +1,3 @@
 from . import common_hooks
 from . import util
-from . import rome_hooks
 from .commons import Hook, HookedModel
52 changes: 42 additions & 10 deletions unseal/hooks/common_hooks.py
@@ -210,7 +210,7 @@ def gpt_attn_wrapper(
     c_proj: torch.Tensor,
     vocab_embedding: torch.Tensor,
     target_ids: torch.Tensor,
-    batch_size: int = 16,
+    batch_size: Optional[int] = None,
 ) -> Tuple[Callable, Callable]:
     """Wraps around the [AttentionBlock]._attn function to save the individual heads' logits.
     This is necessary because the individual heads' logits are not available on a module level and thus not accessible via a hook.
@@ -225,17 +225,21 @@ def gpt_attn_wrapper(
     :type vocab_matrix: torch.Tensor
     :param target_ids: indices of the target tokens for which the logits are computed
     :type target_ids: torch.Tensor
-    :param batch_size: batch size to reduce compute cost
-    :type batch_size: int
+    :param batch_size: batch size to reduce memory footprint, defaults to None
+    :type batch_size: Optional[int]
     :return: inner, func, the wrapped function and the original function
     :rtype: Tuple[Callable, Callable]
     """
     # TODO Find a smarter/more efficient way of implementing this function
+    # TODO clean up this function
     def inner(query, key, value, *args, **kwargs):
         nonlocal c_proj
         nonlocal target_ids
         nonlocal vocab_embedding
+        nonlocal batch_size
         attn_output, attn_weights = func(query, key, value, *args, **kwargs)
+        if batch_size is None:
+            batch_size = attn_output.shape[0]
         with torch.no_grad():
             temp = attn_weights[...,None] * value[:,:,None]
             if len(c_proj.shape) == 2:
@@ -244,21 +248,49 @@ def inner(query, key, value, *args, **kwargs):
             temp = temp[0,:,:-1] # could this be done earlier?
             new_temp = []
             for head in tqdm(range(temp.shape[0])):
+                new_temp.append([])
                 for i in range(math.ceil(temp.shape[1] / batch_size)):
                     out = temp[head, i*batch_size:(i+1)*batch_size] @ c_proj[head]
                     out = out @ vocab_embedding # compute logits
-                    out -= out.mean(dim=-1, keepdim=True) # center logits
-                    # select targets
-                    out = out[...,torch.arange(len(target_ids)), target_ids].to('cpu')
-                    new_temp.append(out)
-            new_temp = torch.cat(new_temp, dim=0)
+                    new_temp[-1].append(out)
+
+                # center logits
+                new_temp[-1] = torch.cat(new_temp[-1])
+                new_temp[-1] -= torch.mean(new_temp[-1], dim=-1, keepdim=True)
+                # select targets
+                new_temp[-1] = new_temp[-1][...,torch.arange(len(target_ids)), target_ids]#.to('cpu')
+
+            new_temp = torch.cat(new_temp, dim=0)
+            new_temp = einops.rearrange(new_temp, '(h t1) t2 -> h t1 t2', h=temp.shape[0], t1=len(target_ids), t2=len(target_ids))
             max_pos_value = torch.amax(new_temp).item()
             max_neg_value = torch.amax(-new_temp).item()
 
             save_ctx['logits'] = {
-                'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach(),
-                'neg': (new_temp/max_neg_value).clamp(min=-1, max=0).detach(),
+                'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach().cpu(),
+                'neg': (-new_temp/max_neg_value).clamp(min=0, max=1).detach().cpu(),
             }
         return attn_output, attn_weights
     return inner, func
 
+#TODO update docs here
+def additive_output_noise(
+    indices: str,
+    mean: Optional[float] = 0,
+    std: Optional[float] = 0.1
+) -> Callable:
+    slice_ = util.create_slice_from_str(indices)
+    def func(save_ctx, input, output):
+        noise = mean + std * torch.randn_like(output[slice_])
+        output[slice_] += noise
+        return output
+    return func
+
+def hidden_patch_hook_fn(
+    position: int,
+    replacement_tensor: torch.Tensor,
+) -> Callable:
+    indices = "...," + str(position) + len(replacement_tensor.shape) * ",:"
+    inner = replace_activation(indices, replacement_tensor)
+    def func(save_ctx, input, output):
+        output[0][...] = inner(save_ctx, input, output[0])
+        return output
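The reworked inner() above replaces the fixed batch size of 16 with an optional one: each head's attention-weighted value vectors are projected through that head's slice of c_proj and through the vocabulary embedding in chunks of batch_size rows (None now means a single chunk), the per-head logits are centered, the target-token entries are gathered, and the result is split into positive and negative maps, each normalised by its own maximum and moved to CPU. Below is a minimal, self-contained sketch of that chunking pattern; the tensor names and shapes are illustrative assumptions, and unlike the committed code it gathers one logit per row instead of building a target-by-target matrix with einops.rearrange.

```python
# Sketch of the chunked per-head logit attribution pattern (illustrative shapes,
# not the exact tensors seen inside the wrapped _attn call).
import math
import torch

def per_head_target_logits(weighted_values, c_proj, vocab_embedding, target_ids, batch_size=None):
    # weighted_values: (heads, rows, head_dim) attention-weighted value vectors
    # c_proj:          (heads, head_dim, model_dim) per-head slice of the output projection
    # vocab_embedding: (model_dim, vocab_size) unembedding matrix
    # target_ids:      (rows,) target token id for each row
    if batch_size is None:
        batch_size = weighted_values.shape[1]  # None -> process each head in one chunk
    per_head = []
    with torch.no_grad():
        for head in range(weighted_values.shape[0]):
            chunks = []
            for i in range(math.ceil(weighted_values.shape[1] / batch_size)):
                chunk = weighted_values[head, i * batch_size:(i + 1) * batch_size]
                chunks.append(chunk @ c_proj[head] @ vocab_embedding)  # (chunk, vocab) logits
            logits = torch.cat(chunks)
            logits = logits - logits.mean(dim=-1, keepdim=True)        # center per head
            per_head.append(logits[torch.arange(len(target_ids)), target_ids])
        out = torch.stack(per_head)                                     # (heads, rows)
        pos = (out / out.amax().clamp(min=1e-8)).clamp(0, 1)            # positive contributions in [0, 1]
        neg = (-out / (-out).amax().clamp(min=1e-8)).clamp(0, 1)        # negative contributions in [0, 1]
    return pos.cpu(), neg.cpu()

heads, rows, head_dim, model_dim, vocab = 2, 5, 4, 8, 11
pos, neg = per_head_target_logits(
    torch.randn(heads, rows, head_dim),
    torch.randn(heads, head_dim, model_dim),
    torch.randn(model_dim, vocab),
    torch.randint(vocab, (rows,)),
    batch_size=2,
)
```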
35 changes: 4 additions & 31 deletions unseal/hooks/rome_hooks.py
@@ -1,34 +1,7 @@
 # Some pre-implemented hooking functions to reproduce the experiments
 # from the ROME paper: https://openreview.net/forum?id=mMECu_poAs
+#TODO for version 1.0.0: Remove this file
+import logging
 
-# Some hooks that I've only used in the context of the ROME paper so far --> possibly migrate this to the unseal_experiments repo
-from typing import Callable, Optional
+from .common_hooks import additive_output_noise, hidden_patch_hook_fn
-
-import torch
+logging.warning("rome_hooks.py is deprecated and will be removed in version 1.0.0. Please use common_hooks.py instead.")
-
-from . import util
-from . import common_hooks
-
-
-def additive_output_noise(
-    indices: str,
-    mean: Optional[float] = 0,
-    std: Optional[float] = 0.1
-) -> Callable:
-    slice_ = util.create_slice_from_str(indices)
-    def func(save_ctx, input, output):
-        noise = mean + std * torch.randn_like(output[slice_])
-        output[slice_] += noise
-        return output
-    return func
-
-def hidden_patch_hook_fn(
-    position: int,
-    replacement_tensor: torch.Tensor,
-) -> Callable:
-    indices = "...," + str(position) + len(replacement_tensor.shape) * ",:"
-    inner = common_hooks.replace_activation(indices, replacement_tensor)
-    def func(save_ctx, input, output):
-        output[0][...] = inner(save_ctx, input, output[0])
-        return output
-    return func
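rome_hooks.py is now a thin deprecation shim: it re-exports additive_output_noise and hidden_patch_hook_fn from common_hooks and logs a warning on import. Both helpers follow the same closure pattern, a factory that captures its arguments and returns a func(save_ctx, input, output) hook. The sketch below imitates that pattern on a plain tensor, with a hand-built slice tuple standing in for whatever util.create_slice_from_str constructs from the index string, so it is an illustration rather than the unseal API.

```python
# Sketch of the hook-factory pattern used by additive_output_noise; the slice
# tuple replaces util.create_slice_from_str and the HookedModel plumbing.
import torch

def additive_noise_hook(slice_, mean=0.0, std=0.1):
    # factory: capture the slice and noise parameters, return an unseal-style hook
    def func(save_ctx, input, output):
        output[slice_] += mean + std * torch.randn_like(output[slice_])
        return output
    return func

hook_fn = additive_noise_hook((slice(None), 0), std=0.5)  # corrupt position 0 of every sequence
hidden = torch.zeros(2, 4, 8)                             # (batch, positions, hidden)
hidden = hook_fn(save_ctx={}, input=None, output=hidden)
```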
2 changes: 1 addition & 1 deletion unseal/visuals/streamlit_interfaces/utils.py
@@ -85,5 +85,5 @@ def text_change(col_idx: Union[int, List[int]]):
         out_proj_name = out_proj_name,
         attn_suffix = attn_suffix,
         unembedding_key = unembedding_key,
-        layer = layer,
+        layer_id = layer,
     )
3 changes: 2 additions & 1 deletion unseal/visuals/utils.py
@@ -25,6 +25,7 @@ def compute_attn_logits(
     attn_suffix: Optional[str] = None,
     unembedding_key: Optional[str] = 'lm_head',
     layer_id: Optional[int] = None,
+    batch_size: Optional[int] = None,
 ):
     # parse inputs
     if save_path is None:
@@ -57,7 +58,7 @@ def compute_attn_logits(
 
     # wrap the _attn function to create logit attribution
     model.save_ctx[f'logit_layer_{layer}'] = dict()
-    old_fn = wrap_gpt_attn(model, layer, target_ids, unembedding_key, attn_name, attn_suffix, layer_key_prefix, out_proj_name)
+    old_fn = wrap_gpt_attn(model, layer, target_ids, unembedding_key, attn_name, attn_suffix, layer_key_prefix, out_proj_name, batch_size)
 
     # forward pass
     model.forward(model_input, hooks=[])
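compute_attn_logits gains an optional batch_size that it simply forwards to wrap_gpt_attn (and from there, presumably, to gpt_attn_wrapper, where None means processing in a single chunk). The sketch below shows that plumbing pattern in isolation; the function names and toy tensors are hypothetical, not the unseal signatures.

```python
# Generic sketch of the Optional[int] batch_size plumbing added in this file:
# outer layers forward the keyword untouched, and only the innermost worker
# interprets None as "one chunk". Function names here are hypothetical.
import math
from typing import Optional

import torch

def chunked_rowsums(x: torch.Tensor, batch_size: Optional[int] = None) -> torch.Tensor:
    if batch_size is None:
        batch_size = x.shape[0]  # a single chunk: fastest, largest peak memory
    parts = [x[i * batch_size:(i + 1) * batch_size].sum(dim=-1)
             for i in range(math.ceil(x.shape[0] / batch_size))]
    return torch.cat(parts)

def outer_api(x: torch.Tensor, batch_size: Optional[int] = None) -> torch.Tensor:
    # mirrors compute_attn_logits -> wrap_gpt_attn: just pass the keyword along
    return chunked_rowsums(x, batch_size)

x = torch.ones(5, 3)
assert torch.equal(outer_api(x), outer_api(x, batch_size=2))  # same result, different memory profile
```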
