Commit c267dc8 (1 parent: c83e9f5)
Showing 184 changed files with 13,772 additions and 5 deletions.
@@ -0,0 +1,19 @@
import torch.nn as nn


class Adaptation(object):
    # Mixin for adaptable layers. It must be combined with an nn.Module that already
    # defines .weight and .bias; it only freezes them and tracks the merge state.
    def __init__(self, flag_adapt_bias: bool, flag_adapt_weight: bool = True,
                 merge_weights: bool = True, freeze_weight: bool = True, freeze_bias: bool = True):
        assert isinstance(self, nn.Module)
        self.flag_adapt_bias = flag_adapt_bias
        self.flag_adapt_weight = flag_adapt_weight
        self.weight.requires_grad = not freeze_weight
        if self.bias is not None:
            self.bias.requires_grad = not freeze_bias
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

    def assign_adaptation(self, adaptation):
        raise NotImplementedError
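Editor's note: a minimal sketch of how this mixin is meant to be combined with a concrete nn.Module. It is not part of the committed file; the class name and the scale/shift layout are illustrative assumptions, and the import path is taken from the generator file later in this commit.

import torch.nn as nn

from adapter.module.base import Adaptation  # path inferred from imports elsewhere in this commit


class ScaleShiftLinear(nn.Linear, Adaptation):
    # Hypothetical subclass: a frozen Linear whose output is modulated by an
    # externally generated scale/shift pair.
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        Adaptation.__init__(self, flag_adapt_bias=bias, **kwargs)
        self.scale, self.shift = None, None

    def assign_adaptation(self, adaptation):
        # adaptation: (..., 2 * out_features); first half are scale offsets, second half shifts.
        if adaptation is None:
            self.scale, self.shift = None, None
            return
        self.scale = adaptation[..., :self.out_features] + 1
        self.shift = adaptation[..., self.out_features:] if self.flag_adapt_bias else None

    def forward(self, x):
        out = nn.Linear.forward(self, x)
        if self.scale is not None:
            out = out * self.scale
            if self.shift is not None:
                out = out + self.shift
        return out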
@@ -0,0 +1,181 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from adapter.module import ssf, up
from adapter.module.up import LayerNorm, BatchNorm1d, RevIN, _RevIN


def add_down_up_(parent_module: nn.Module, module_name: str, freeze_weight: bool,
                 merge_weights=False, load_weights=True, **kwargs):
    # Replace parent_module.<module_name> in place with its Down_Up-adapted counterpart,
    # preserving the original hyperparameters and, optionally, the pretrained weights.
    old_module = getattr(parent_module, module_name)
    if isinstance(old_module, nn.Linear):
        new_module = Linear(in_features=old_module.in_features, out_features=old_module.out_features,
                            bias=old_module.bias is not None, freeze_weight=freeze_weight,
                            device=old_module.weight.device, dtype=old_module.weight.dtype,
                            merge_weights=merge_weights, **kwargs)
    elif isinstance(old_module, transformers.Conv1D):
        new_module = TFMConv1D(nx=old_module.weight.shape[0], nf=old_module.nf, merge_weights=merge_weights,
                               freeze_weight=freeze_weight, **kwargs)
    elif isinstance(old_module, _RevIN) and old_module.affine:
        new_module = RevIN(num_features=old_module.num_features, eps=old_module.eps, affine=True,
                           subtract_last=old_module.subtract_last,
                           merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)
    elif isinstance(old_module, nn.BatchNorm1d) and old_module.affine:
        new_module = BatchNorm1d(num_features=old_module.num_features, eps=old_module.eps,
                                 momentum=old_module.momentum, affine=old_module.affine,
                                 track_running_stats=old_module.track_running_stats,
                                 device=old_module.weight.device, dtype=old_module.weight.dtype,
                                 merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)
    elif isinstance(old_module, nn.LayerNorm) and len(old_module.normalized_shape) == 1 and old_module.elementwise_affine:
        new_module = LayerNorm(normalized_shape=old_module.normalized_shape[-1], eps=old_module.eps,
                               elementwise_affine=True, device=old_module.weight.device, dtype=old_module.weight.dtype,
                               merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)
    elif isinstance(old_module, nn.Conv1d):
        new_module = Conv1d(old_module.in_channels, old_module.out_channels,
                            kernel_size=old_module.kernel_size, stride=old_module.stride, padding=old_module.padding,
                            dilation=old_module.dilation, groups=old_module.groups, bias=old_module.bias is not None,
                            padding_mode=old_module.padding_mode,
                            device=old_module.weight.device, dtype=old_module.weight.dtype,
                            freeze_weight=freeze_weight, merge_weights=merge_weights, **kwargs)
    else:
        raise NotImplementedError
    if load_weights:
        new_module.load_state_dict(old_module.state_dict(), strict=False)
    setattr(parent_module, module_name, new_module)
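Editor's note: a hypothetical usage sketch, not part of the committed file. add_down_up_ swaps each supported child in place, so iterating over a parent's children is enough to adapt a whole block; the toy backbone and argument values are assumptions.

import torch.nn as nn

from adapter.module.down_up import add_down_up_  # assuming this file is adapter/module/down_up.py

backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
for child_name, child in list(backbone.named_children()):
    if isinstance(child, nn.Linear):
        # Keep the pretrained weights frozen and copy them into the adapted module.
        add_down_up_(backbone, child_name, freeze_weight=True, merge_weights=False, load_weights=True)
# backbone[0] and backbone[2] are now adapted Linear layers that behave like the
# originals until assign_adaptation() is called on them.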


class Down_Up(up.Adaptation_Up):
    # Extends the output-side scale/shift of Adaptation_Up with an additional input-side
    # ("down") scale, so an adaptation vector is laid out as
    # [output scale | input scale | output shift].
    def __init__(self, in_features: int, *args, **kwargs):
        up.Adaptation_Up.__init__(self, *args, **kwargs)
        self.in_features = in_features
        self.register_buffer('scale2', None, persistent=False)

    def assign_adaptation(self, adaptation):
        if adaptation is None:
            self.scale, self.scale2, self.shift = None, None, None
        else:
            self.scale = adaptation[..., :self.out_features] + 1
            self.scale2 = adaptation[..., self.out_features:self.out_features + self.in_features] + 1
            if self.flag_adapt_bias:
                self.shift = adaptation[..., -self.out_features:]
            if self.scale.dim() == 2:
                self.scale = self.scale.unsqueeze(1)
                self.scale2 = self.scale2.unsqueeze(1)
                self.shift = self.shift.unsqueeze(1)

    def _merge(self, weight, bias):
        weight, bias = super()._merge(weight, bias)
        if weight is not None and self.flag_adapt_weight:
            scale2 = self.scale2.squeeze()
            if self.fan_in_fan_out:
                weight = weight * scale2.reshape(scale2.shape[-1:] + (1,) * (weight.dim() - 1))
            else:
                weight = weight * scale2.reshape((1, scale2.shape[-1]) + (1,) * (weight.dim() - 2))
        return weight, bias

    def _ssf_input(self, x: torch.Tensor):
        # Per-sample scaling of the input features before the wrapped linear map.
        batch_size = x.size()[:-1]
        x = x.reshape(self.scale2.shape[0], -1, x.shape[-1])
        return (x * self.scale2).view(*batch_size, x.shape[-1])

    def _ssf(self, res: torch.Tensor):
        # Per-sample scale/shift of the output features after the wrapped linear map.
        batch_size = res.size()[:-1]
        res = res.view(self.scale.shape[0], -1, res.shape[-1])
        if self.bias is not None:
            res = res * self.scale + (self.shift + self.bias)
        else:
            res = res * self.scale
        return res.view(*batch_size, res.shape[-1])
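Editor's note: the unmerged path above amounts to a per-sample input scale, the frozen linear map, then a per-sample output scale/shift. A small pure-torch sketch of that algebra, independent of the classes in this file; the shapes are illustrative only.

import torch

B, I, O = 4, 3, 5                          # batch, in_features, out_features
W, b = torch.randn(O, I), torch.randn(O)   # frozen weight and bias
x = torch.randn(B, I)
scale = torch.randn(B, O) + 1              # output ("up") scale
scale2 = torch.randn(B, I) + 1             # input ("down") scale
shift = torch.randn(B, O)                  # output shift

# What _ssf(F.linear(_ssf_input(x), W)) computes, sample by sample:
y = ((x * scale2) @ W.T) * scale + shift + b

# Equivalent "merged" view: each sample sees W with its rows scaled by `scale`
# and its columns scaled by `scale2`, plus a bias of `shift + b`.
y_ref = torch.stack([x[i] @ (W * scale[i].unsqueeze(1) * scale2[i].unsqueeze(0)).T + shift[i] + b
                     for i in range(B)])
assert torch.allclose(y, y_ref, atol=1e-5)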


class Linear(Down_Up, ssf.Linear):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None, dtype=None,
            merge_weights: bool = True, freeze_weight: bool = True,
            **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, bias=bias, device=device, dtype=dtype)
        Down_Up.__init__(self, in_features=in_features, out_features=out_features, flag_adapt_bias=bias,
                         merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)

    def forward(self, x: torch.Tensor):
        if not hasattr(self, 'scale') or self.scale is None:
            # No adaptation assigned: behave exactly like a plain nn.Linear.
            return nn.Linear.forward(self, x)
        if self.merged:
            return F.linear(x, self.weight, bias=self.bias)
        elif self.scale.shape[0] == 1 and self.merge_weights:
            # A single shared adaptation can be folded into the weight and bias once.
            weight, bias = self._merge(self.weight, self.bias)
            return F.linear(x, weight, bias=bias)
        else:
            # Per-sample adaptation: scale the input, apply the frozen weight, then scale/shift the output.
            return self._ssf(F.linear(self._ssf_input(x), self.weight, bias=None))
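Editor's note: a hypothetical smoke test for the class above; it assumes adapter.module.up.Adaptation_Up accepts the keyword arguments Down_Up forwards to it, and in the commit these layers are normally created through add_down_up_ instead. An all-zero adaptation vector yields scales of 1 and a shift of 0, so the layer reproduces a plain nn.Linear.

import torch

layer = Linear(in_features=16, out_features=32, bias=True,
               merge_weights=False, freeze_weight=True)
adaptation = torch.zeros(2, 32 + 16 + 32)   # per sample: [out scale | in scale | out shift]
layer.assign_adaptation(adaptation)          # a batch of 2 concept-conditioned adaptations
x = torch.randn(2 * 5, 16)                   # 5 rows per sample, flattened as _ssf_input expects
out = layer(x)                               # matches nn.Linear up to floating point
layer.assign_adaptation(None)                # fall back to the plain nn.Linear path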


class TFMConv1D(Down_Up, ssf.AttnConv1D):
    # Adapter for transformers.Conv1D (the GPT-2 style linear layer with transposed weight).
    def __init__(
            self,
            nx: int,
            nf: int,
            merge_weights: bool = True, freeze_weight: bool = True,
            **kwargs
    ):
        transformers.Conv1D.__init__(self, nx=nx, nf=nf)
        Down_Up.__init__(self, in_features=nx, out_features=nf, flag_adapt_bias=True, fan_in_fan_out=True,
                         merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)

    def forward(self, x: torch.Tensor):
        if not hasattr(self, 'scale') or self.scale is None:
            return transformers.Conv1D.forward(self, x)
        if self.merged:
            size_out = x.size()[:-1] + (self.nf,)
            return torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight).view(size_out)
        elif self.scale.shape[0] == 1 and self.merge_weights:
            weight, bias = self._merge(self.weight, self.bias)
            return F.linear(x, weight.transpose(-1, -2), bias=bias)
        else:
            return self._ssf(F.linear(self._ssf_input(x), self.weight.transpose(-1, -2), bias=None))


class Conv1d(Down_Up, ssf.Conv1d):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True,
                 padding_mode: str = 'zeros', device=None, dtype=None,
                 merge_weights: bool = True, freeze_weight: bool = True, **kwargs):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
        Down_Up.__init__(self, in_features=in_channels, out_features=out_channels, flag_adapt_bias=bias,
                         merge_weights=merge_weights, freeze_weight=freeze_weight, **kwargs)

    def forward(self, x):
        if not hasattr(self, 'scale') or self.scale is None:
            return nn.Conv1d.forward(self, x)
        if self.merged:
            return self._conv_forward(x, self.weight, bias=self.bias)
        elif self.scale.shape[0] == 1 and self.merge_weights:
            weight, bias = self._merge(self.weight, self.bias)
            return self._conv_forward(x, weight, bias=bias)
        else:
            return self._ssf(self._conv_forward(self._ssf_input(x), self.weight, bias=None))

    def _ssf_input(self, x: torch.Tensor):
        # Scale the input channels; the trailing length dimension is broadcast over.
        batch_size = x.size()[:-2]
        x = x.view(self.scale2.shape[0], -1, *x.shape[-2:]) * self.scale2.unsqueeze(-1)
        return x.view(*batch_size, *x.shape[-2:])

    def _ssf(self, res: torch.Tensor):
        # Scale/shift the output channels; the trailing length dimension is broadcast over.
        batch_size = res.size()[:-2]
        res = res.view(self.scale.shape[0], -1, *res.shape[-2:])
        if self.bias is not None:
            res = res * self.scale.unsqueeze(-1) + (self.shift + self.bias).unsqueeze(-1)
        else:
            res = res * self.scale.unsqueeze(-1)
        return res.view(*batch_size, *res.shape[-2:])
@@ -0,0 +1,99 @@
import collections
import math
import typing

import torch
import torch.nn as nn
import transformers

from adapter.module import down_up
from adapter.module.base import Adaptation


def clip(x, max_norm=1):
    # Rescale each last-dim vector of x onto the ball of radius max_norm and also return the original norms.
    x_norm = torch.norm(x, dim=-1, keepdim=True)
    scale = torch.clip(max_norm / x_norm, max=1)
    return x * scale, x_norm
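Editor's note: a quick illustration of clip. Vectors whose last-dim norm exceeds max_norm are rescaled onto the ball, shorter ones pass through unchanged, and the original norms are returned alongside.

x = torch.tensor([[3.0, 4.0], [0.3, 0.4]])
clipped, norms = clip(x)
# clipped -> [[0.6, 0.8], [0.3, 0.4]]  (only the first row had norm > 1)
# norms   -> [[5.0], [0.5]]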


class Bottleneck(nn.Module):
    # Hypernetwork head: maps a concept vector to one adaptation vector per adapted layer
    # via a shared (or per-layer) down-projection and a per-layer up-projection.
    def __init__(self, in_dim: int, out_dim: int, n_layers: int,
                 bottleneck_dim: typing.Union[typing.List[int], int], activation=nn.LeakyReLU, need_bias: bool = False,
                 shared=False, rand_init: dict = None):
        super().__init__()
        self.out_dim = out_dim
        self.activation = activation()
        self.shared = shared
        self.need_bias = need_bias
        self.biases = nn.ParameterList([nn.Parameter(torch.empty(n_layers, 1, bottleneck_dim))])
        self.weights = nn.ParameterList([nn.Linear(in_dim, bottleneck_dim, bias=False) if self.shared else
                                         nn.Parameter(torch.empty(n_layers, 1, bottleneck_dim, in_dim))])
        if self.need_bias:
            self.biases.append(nn.Parameter(torch.zeros(n_layers, 1, out_dim)))
        # The up-projection starts at zero so the generated adaptation is initially the identity.
        self.weights.append(nn.Parameter(torch.zeros(1 if self.shared else n_layers, 1, out_dim, bottleneck_dim)))
        for i in range(0, len(self.weights) - 1):
            for j in range(n_layers):
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                    self.weights[i].weight if shared else self.weights[i][0, 0])
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(self.biases[i][j], -bound, bound)
                if not self.shared:
                    nn.init.kaiming_uniform_(self.weights[i][j, 0], a=math.sqrt(5))
        if self.need_bias:
            nn.init.zeros_(self.biases[-1])
            if rand_init is not None:
                for pos, (r, in_fea) in rand_init.items():
                    nn.init.kaiming_uniform_(self.biases[-1][pos, :, :r * in_fea], a=math.sqrt(5))

        self.memories = None
        self.last_adaptation = None

    def forward(self, x, mask=None, training=False):
        if self.shared:
            x = self.activation(self.weights[0](x) + self.biases[0]).unsqueeze(-1)
        else:
            x = self.activation(self.weights[0] @ x.unsqueeze(-1) + self.biases[0].unsqueeze(-1))
        x = self.weights[-1] @ x
        x = x.squeeze(-1)
        if training and mask is not None:
            x = x * mask.unsqueeze(-1)
        if self.need_bias:
            x = x + self.biases[-1]
        return x
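Editor's note: a shape-only sketch of the head above; the dimensions are made up. One Bottleneck serves every adapted layer in a group and emits one adaptation vector per layer and per sample.

import torch

gen = Bottleneck(in_dim=32, out_dim=80, n_layers=4, bottleneck_dim=16,
                 shared=True, need_bias=True)
concept = torch.randn(2, 32)       # (batch, concept_features)
coefs = gen(concept)               # (n_layers, batch, out_dim) == (4, 2, 80)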


class AdaptGenerator(nn.Module):
    # Scans the backbone for Adaptation layers, groups them by attribute name and adaptation size,
    # and builds one Bottleneck per group to generate their adaptation vectors.
    def __init__(self, backbone: nn.Module, concept_features: int, activation=nn.LeakyReLU,
                 shared: bool = True, need_bias: bool = True, adaptive_dim: bool = False, mid_dim: int = None):
        super().__init__()
        self.dim_name_dict = collections.defaultdict(list)
        self.bottlenecks = nn.ModuleDict()
        self.loras = nn.ModuleDict()
        self.memories = None
        self.memories_lora = None
        for name, module in backbone.named_modules():
            if isinstance(module, Adaptation):
                _weight = module.affine_weight if hasattr(module, 'affine_weight') else module.weight
                out_features = _weight.shape[1 if isinstance(module, transformers.Conv1D) else 0]
                if _weight.dim() == 1:
                    in_features = 1
                else:
                    in_features = _weight.shape[0 if isinstance(module, transformers.Conv1D) else 1]
                # Adaptation size: output scale, plus output shift if biased, plus input scale for Down_Up layers.
                out_dim = out_features
                if module.bias is not None:
                    out_dim += out_features
                if isinstance(module, down_up.Down_Up):
                    out_dim += in_features
                self.dim_name_dict[name.split('.')[-1] + '_' + str(out_dim)].append(name)

        for key, names in self.dim_name_dict.items():
            out_dim = int(key.split('_')[-1])
            _hid_dims = min(mid_dim, out_dim // 4) if adaptive_dim else mid_dim
            self.bottlenecks[key] = Bottleneck(concept_features, out_dim, len(names), _hid_dims,
                                               activation, shared=shared, need_bias=need_bias)

    def forward(self, x, need_clip=False):
        if need_clip:
            x, x_norm = clip(x)
        coefs = {k: bottleneck(x) for k, bottleneck in self.bottlenecks.items()}
        return coefs
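Editor's note: putting the pieces of this commit together, a hypothetical end-to-end flow. The toy backbone, the dimensions, and the final assignment loop are assumptions; the actual wiring presumably lives elsewhere in the repository.

import torch
import torch.nn as nn

from adapter.module.down_up import add_down_up_

backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
for child_name, child in list(backbone.named_children()):
    if isinstance(child, nn.Linear):
        add_down_up_(backbone, child_name, freeze_weight=True)

generator = AdaptGenerator(backbone, concept_features=8, mid_dim=16)
coefs = generator(torch.randn(1, 8))          # {group key: (n_layers_in_group, 1, out_dim)}

# Assumed assignment step: hand each layer in a group its slice of the generated coefficients.
modules = dict(backbone.named_modules())
for key, names in generator.dim_name_dict.items():
    for i, name in enumerate(names):
        modules[name].assign_adaptation(coefs[key][i])

out = backbone(torch.randn(1, 16))            # forward pass with the per-concept adaptation applied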