train.py
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import torch
import transformers
from struq import SupervisedDataset
from config import IGNORE_INDEX, DEFAULT_TOKENS, SPECIAL_DELM_TOKENS, TEXTUAL_DELM_TOKENS

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    window_size: int = field(default=0, metadata={"help": "Window size for the sliding window attention."})
    padding_side: str = field(default="right", metadata={"help": "Padding side for tokenization."})

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})

@dataclass
class AttackArguments:
    attack: str = field(default='alpaca', metadata={"help": "Attack type."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    downsample: Optional[bool] = field(default=True)  # whether SupervisedDataset downsamples the training data
    lr_scale: Optional[bool] = field(default=True)  # if not downsampling, divide the learning rate by the dataset's copy count

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Pad input_ids with the pad token and labels with IGNORE_INDEX so padded positions are ignored in the loss.
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
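
# get_embedding_indices returns, for every textual delimiter in TEXTUAL_DELM_TOKENS, the id of its
# first sub-token (used below to seed the embeddings of the newly added special delimiter tokens),
# plus the ids of all tokens that decode to "#" (returned but unused in this script).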
def get_embedding_indices(tokenizer):
    init_values = [tokenizer.encode(v, add_special_tokens=False)[0] for v in TEXTUAL_DELM_TOKENS]
    ignore_values = [i for i in range(len(tokenizer)) if tokenizer.decode(i) == "#"]
    return init_values, ignore_values

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: this is the unoptimized version, which may make your embedding size not divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    REAL_DELIMITERS_INIT_EMBD_IND, _ = get_embedding_indices(tokenizer)
    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        # The first newly added token (the pad token, assuming the other default tokens already exist)
        # is initialized to the average of the pre-existing embeddings.
        input_embeddings[-num_new_tokens] = input_embeddings_avg
        output_embeddings[-num_new_tokens] = output_embeddings_avg
        # Initialize each special delimiter's embedding from its existing textual counterpart.
        for i in range(len(SPECIAL_DELM_TOKENS)):
            input_embeddings[-num_new_tokens + i + 1] = input_embeddings[REAL_DELIMITERS_INIT_EMBD_IND[i]]
            output_embeddings[-num_new_tokens + i + 1] = output_embeddings[REAL_DELIMITERS_INIT_EMBD_IND[i]]

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, downsample=True) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, attack=data_args.attack, downsample=downsample)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, AttackArguments))
    model_args, data_args, training_args, attack_args = parser.parse_args_into_dataclasses()
    data_args.attack = attack_args.attack
    print('\n\n' + training_args.output_dir + '\n\n')
    print(model_args.model_name_or_path)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    if model_args.window_size > 0:
        model.config.window = model_args.window_size
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side=model_args.padding_side,
        use_fast=False,
    )

    # Register the pad/eos/bos/unk tokens and the special delimiter tokens, then resize the embeddings to match.
    special_tokens_dict = dict()
    special_tokens_dict["pad_token"] = DEFAULT_TOKENS['pad_token']
    special_tokens_dict["eos_token"] = DEFAULT_TOKENS['eos_token']
    special_tokens_dict["bos_token"] = DEFAULT_TOKENS['bos_token']
    special_tokens_dict["unk_token"] = DEFAULT_TOKENS['unk_token']
    special_tokens_dict["additional_special_tokens"] = SPECIAL_DELM_TOKENS
    smart_tokenizer_and_embedding_resize(special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model)

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, downsample=training_args.downsample)
    if not training_args.downsample and training_args.lr_scale:
        # When training on the full (non-downsampled) data, scale the learning rate down by the dataset's copy count.
        training_args.learning_rate /= data_module["train_dataset"].data_copy_count

    trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()
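
# Example invocation (a hypothetical sketch; the data path, output directory, and hyperparameter
# values below are placeholders, not values taken from this repository). HfArgumentParser exposes
# every dataclass field above as a command-line flag, so a run looks roughly like:
#   python train.py \
#       --model_name_or_path facebook/opt-125m \
#       --data_path data/training_data.json \
#       --attack alpaca \
#       --output_dir outputs/run0 \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 2e-5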