# finetuning_model.py
# Forked from huggingface/naacl_transfer_learning_tutorial
import torch
import torch.nn as nn

from pretraining_model import Transformer, TransformerWithLMHead


class TransformerWithAdapters(Transformer):
    def __init__(self, adapters_dim, embed_dim, hidden_dim, num_embeddings, num_max_positions,
                 num_heads, num_layers, dropout, causal):
        """ Transformer with adapters (small bottleneck layers added inside each block) """
        super().__init__(embed_dim, hidden_dim, num_embeddings, num_max_positions,
                         num_heads, num_layers, dropout, causal)
        self.adapters_1 = nn.ModuleList()
        self.adapters_2 = nn.ModuleList()
        for _ in range(num_layers):
            # Each adapter is a bottleneck: project down to adapters_dim, ReLU, project back to embed_dim
            self.adapters_1.append(nn.Sequential(nn.Linear(embed_dim, adapters_dim),
                                                 nn.ReLU(),
                                                 nn.Linear(adapters_dim, embed_dim)))
            self.adapters_2.append(nn.Sequential(nn.Linear(embed_dim, adapters_dim),
                                                 nn.ReLU(),
                                                 nn.Linear(adapters_dim, embed_dim)))

    def forward(self, x, padding_mask=None):
        """ x has shape [seq length, batch]; padding_mask, if provided, has shape [batch, seq length] """
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for l in range(len(self.layer_norms_1)):
            h = self.layer_norms_1[l](h)
            x, _ = self.attentions[l](h, h, h, attn_mask=attn_mask, need_weights=False,
                                      key_padding_mask=padding_mask)
            x = self.dropout(x)
            x = self.adapters_1[l](x) + x  # Adapter (with skip-connection) after the attention sub-layer
            h = x + h

            h = self.layer_norms_2[l](h)
            x = self.feed_forwards[l](h)
            x = self.dropout(x)
            x = self.adapters_2[l](x) + x  # Adapter (with skip-connection) after the feed-forward sub-layer
            h = x + h
        return h
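

# Adapter-based fine-tuning typically keeps the pretrained transformer weights frozen and trains only
# the newly added bottleneck layers (plus, usually, the layer norms and the task head). That freezing
# step is not part of this file; the helper below is a minimal sketch with a name of our choosing.
def freeze_all_but_adapters_and_head(model):
    """ Sketch: train only adapters, layer norms and the classification head; freeze everything else. """
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in ('adapters_1', 'adapters_2',
                                                          'layer_norms_1', 'layer_norms_2',
                                                          'classification_head'))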


class TransformerWithClfHead(TransformerWithLMHead):
    def __init__(self, config, fine_tuning_config):
        """ Transformer with a language modeling head and a classification head on top, and optionally adapters. """
        super().__init__(config)
        self.config = fine_tuning_config
        if fine_tuning_config.adapters_dim > 0:
            # Replace the plain transformer with an adapter-augmented version
            self.transformer = TransformerWithAdapters(fine_tuning_config.adapters_dim, config.embed_dim, config.hidden_dim,
                                                       config.num_embeddings, config.num_max_positions, config.num_heads,
                                                       config.num_layers, fine_tuning_config.dropout, causal=not config.mlm)

        self.classification_head = nn.Linear(config.embed_dim, fine_tuning_config.num_classes)
        self.apply(self.init_weights)

    def forward(self, x, clf_tokens_mask, lm_labels=None, clf_labels=None, padding_mask=None):
        """ x and clf_tokens_mask have shape [seq length, batch]; padding_mask has shape [batch, seq length] """
        hidden_states = self.transformer(x, padding_mask)
        lm_logits = self.lm_head(hidden_states)

        # Pool the hidden states at the classification token positions and project to class logits
        clf_tokens_states = (hidden_states * clf_tokens_mask.unsqueeze(-1).float()).sum(dim=0)
        clf_logits = self.classification_head(clf_tokens_states)

        loss = []
        if clf_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss.append(loss_fct(clf_logits.view(-1, clf_logits.size(-1)), clf_labels.view(-1)))
        if lm_labels is not None:
            # For a causal model, position t predicts token t+1, so shift logits and labels by one
            shift_logits = lm_logits[:-1] if self.transformer.causal else lm_logits
            shift_labels = lm_labels[1:] if self.transformer.causal else lm_labels
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))

        if len(loss):
            return (lm_logits, clf_logits), loss
        return lm_logits, clf_logits
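

# Minimal usage sketch (not part of the original tutorial file). The attribute names on the two config
# objects mirror the ones accessed above; `initializer_range` is an assumption about what the inherited
# `init_weights` from pretraining_model expects, and all numeric values are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(embed_dim=256, hidden_dim=1024, num_embeddings=5000, num_max_positions=256,
                             num_heads=8, num_layers=4, dropout=0.1, mlm=False, initializer_range=0.02)
    fine_tuning_config = SimpleNamespace(adapters_dim=32, dropout=0.1, num_classes=2, initializer_range=0.02)

    model = TransformerWithClfHead(config, fine_tuning_config)

    seq_len, batch_size = 16, 4
    x = torch.randint(config.num_embeddings, (seq_len, batch_size))       # token indices, shape [seq length, batch]
    clf_tokens_mask = torch.zeros(seq_len, batch_size, dtype=torch.bool)
    clf_tokens_mask[-1] = True                                            # classify from the last position's hidden state

    lm_logits, clf_logits = model(x, clf_tokens_mask)
    print(lm_logits.shape, clf_logits.shape)  # [seq length, batch, num_embeddings] and [batch, num_classes]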