# convert_bart_to_longformerencoderdecoder.py
import argparse
import logging
import os
import copy

from transformers import MBart50TokenizerFast as BartTokenizer
from transformers import MBartForConditionalGeneration as BartForConditionalGeneration
# from transformers.models.bart.modeling_bart import shift_tokens_right

from longformer.longformer_encoder_decoder import LongformerSelfAttentionForBart, LongformerEncoderDecoderConfig
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(
    save_model_to,
    base_model,
    tokenizer_name_or_path,
    attention_window,
    max_pos
):
    model = BartForConditionalGeneration.from_pretrained(base_model)
    tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
    config = LongformerEncoderDecoderConfig.from_pretrained(base_model)
    model.config = config

    # in BART, attention dropout is stored as attention_dropout, but LongformerSelfAttention
    # expects attention_probs_dropout_prob, so set it here
    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    del config.max_position_embeddings
    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    assert max_pos >= current_max_pos
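    # Worked example, assuming facebook/mbart-large-50 (1024 positions, d_model=1024) and the
    # default max_pos=4096: current_max_pos is 1026, the new encoder table below is 4098 x 1024,
    # and the pretrained rows [2:1026] are tiled four times to fill rows [2:4098].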
    # allocate a larger position embedding matrix for the encoder
    new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # copy position embeddings over and over to initialize the new position embeddings
    k = 2
    step = current_max_pos - 2
    while k < max_pos - 1:
        new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
        k += step
    model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    # allocate a larger position embedding matrix for the decoder
    # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # # copy position embeddings over and over to initialize the new position embeddings
    # k = 2
    # step = current_max_pos - 2
    # while k < max_pos - 1:
    #     new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
    #     k += step
    # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed

    # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers

    for i, layer in enumerate(model.model.encoder.layers):
        longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)

        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj

        longformer_self_attn_for_bart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_bart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_bart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_bart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_bart

    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer
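

# A minimal sanity-check sketch, assuming facebook/mbart-large-50 as the checkpoint to convert
# and '/tmp/longformer-mbart-4096' as a placeholder output path:
#
#   model, tokenizer = create_long_model(
#       save_model_to='/tmp/longformer-mbart-4096',
#       base_model='facebook/mbart-large-50',
#       tokenizer_name_or_path='facebook/mbart-large-50',
#       attention_window=128,
#       max_pos=4096,
#   )
#   # the extended encoder position table should have max_pos + 2 rows ...
#   assert model.model.encoder.embed_positions.weight.shape[0] == 4096 + 2
#   # ... and every encoder layer should now use Longformer self-attention
#   assert all(isinstance(layer.self_attn, LongformerSelfAttentionForBart)
#              for layer in model.model.encoder.layers)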


def main():
    parser = argparse.ArgumentParser(description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention")
    parser.add_argument(
        '--base_model',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the base model you want to convert'
    )
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the tokenizer'
    )
    parser.add_argument(
        '--save_model_to',
        type=str,
        required=True,
        help='The path to save the converted model'
    )
    parser.add_argument(
        '--attention_window',
        type=int,
        default=128,
        help='attention window size for longformer self attention (one sided)'
    )
    parser.add_argument(
        '--max_pos',
        type=int,
        default=4096,
        help='maximum encoder positions'
    )
    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)

    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        attention_window=args.attention_window,
        max_pos=args.max_pos
    )
    # tokenizer = BartTokenizer.from_pretrained(args.save_model_to)
    # TXT = "My friends are <mask> but they eat too many carbs."
    # model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)
    # model.model.encoder.config.gradient_checkpointing = True
    # model.model.decoder.config.gradient_checkpointing = True
    # data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
    # input_ids = data['input_ids']
    # attention_mask = data['attention_mask']
    # decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
    # logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
    # masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    # probs = logits[0, masked_index].softmax(dim=0)
    # values, predictions = probs.topk(5)
    # print(tokenizer.convert_ids_to_tokens(predictions))


if __name__ == "__main__":
    main()
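
# Usage sketch, assuming facebook/mbart-large-50 as the checkpoint to convert and
# '/path/to/longformer-mbart-4096' as a placeholder output directory:
#
#   python convert_bart_to_longformerencoderdecoder.py \
#       --base_model facebook/mbart-large-50 \
#       --tokenizer_name_or_path facebook/mbart-large-50 \
#       --save_model_to /path/to/longformer-mbart-4096 \
#       --attention_window 128 \
#       --max_pos 4096
#
# The saved checkpoint can then be reloaded with the aliased MBart50TokenizerFast and
# LongformerEncoderDecoderForConditionalGeneration, as in the commented-out block above:
#
#   tokenizer = BartTokenizer.from_pretrained('/path/to/longformer-mbart-4096')
#   model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained('/path/to/longformer-mbart-4096')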