# train_math.py
import os
import torch
import random
import pathlib
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
import transformers
from trl import SFTTrainer
from datasets import load_dataset
# Note: `deepspeed` here is transformers' DeepSpeed integration module
# (used below for is_deepspeed_zero3_enabled), distinct from the
# `deepspeed` package imported on the next two lines.
from transformers import AutoModelForCausalLM, AutoTokenizer, deepspeed
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
SEED = 42
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Vietnamese instruction prompt. English translation:
# "Below is an instruction describing a problem. Write a response that
#  appropriately completes the request.
#  ### Question: {question}\n{choices}
#  ### Answer: {answer}"
INS_PROMPT = '''
Dưới đây là hướng dẫn mô tả bài toán. \
Viết câu trả lời hoàn thành yêu cầu một cách thích hợp\n\
### Câu hỏi: {question}\n{choices}\n\
### Trả lời: {answer}
'''
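# When formatted, a sample renders roughly as follows (the question, choices,
# and answer here are made up for illustration, not taken from the dataset):
#
#   Dưới đây là hướng dẫn mô tả bài toán. Viết câu trả lời hoàn thành yêu cầu một cách thích hợp
#   ### Câu hỏi: 2 + 3 = ?
#   A. 4
#   B. 5
#   ### Trả lời: B. 5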
# Reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Note: on CUDA this may additionally require CUBLAS_WORKSPACE_CONFIG=:4096:8
# to be set, and it raises an error for ops with no deterministic implementation.
torch.use_deterministic_algorithms(True)
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class LoraArguments:
    lora_enable: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    val_nums: int = field(default=None, metadata={"help": "Number of samples in the validation set."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=None,
        metadata={"help": "Maximum sequence length. Sequences will be right-padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = field(default=True)
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
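# For illustration: with the default facebook/opt-125m checkpoint, the names
# collected above would be the decoder's linear layers, e.g. something like
# ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'] (order varies,
# since sets are unordered), with 'lm_head' filtered out as noted.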
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


def maybe_zero_3(param):
    # Under DeepSpeed ZeRO-3, parameters are partitioned across ranks and must
    # be gathered before their data can be copied to the CPU.
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the bias tensors that belong to a LoRA-adapted module.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return
# Instruction formatting
def formatting_func(sample):
    # Prepend the explanation to the answer when one is available.
    if sample['explanation']:
        ans = sample['explanation'] + '\n' + sample['answer']
    else:
        ans = sample['answer']
    choices = '\n'.join(sample['choices'])
    text = INS_PROMPT.format(question=sample['question'], choices=choices, answer=ans)
    return text
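# For reference, the loader in train() reads a JSON file whose top-level
# "data" field holds the samples consumed by formatting_func. The field names
# below come from the code; the concrete values are only illustrative:
#
# {
#   "data": [
#     {
#       "question": "...",
#       "choices": ["A. ...", "B. ...", "C. ...", "D. ..."],
#       "answer": "A. ...",
#       "explanation": ""
#     }
#   ]
# }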
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    json_data = load_dataset('json', data_files=data_args.data_path, field='data')
    if data_args.val_nums is None:
        train_data = json_data['train']
        val_data = None
    else:
        dataset = json_data['train'].train_test_split(test_size=data_args.val_nums, shuffle=True, seed=SEED)
        train_data = dataset['train']
        val_data = dataset['test']

    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
    if lora_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.local_rank == 0:
            print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Note: this matches the older trl SFTTrainer API, where `tokenizer`,
    # `packing`, and `formatting_func` are passed directly to the trainer;
    # newer trl versions move packing into SFTConfig.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        packing=True,
        tokenizer=tokenizer,
        formatting_func=formatting_func,
    )

    # Resume from the latest checkpoint if one exists in the output directory.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    if deepspeed.is_deepspeed_zero3_enabled():
        # Under ZeRO-3 the weights are sharded across ranks, so consolidate a
        # full 16-bit state dict before saving on rank 0.
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        if training_args.local_rank == 0:
            state_dict = state_dict_zero3
    else:
        # In other modes we use the original code from the FastChat team,
        # to keep our changes minimal.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()
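# Illustrative launch command. The flags map to the dataclass fields and to
# standard transformers.TrainingArguments; the paths and hyperparameter
# values below are placeholders, not taken from the repository:
#
#   torchrun --nproc_per_node=4 train_math.py \
#       --model_name_or_path facebook/opt-125m \
#       --data_path data/train.json \
#       --val_nums 500 \
#       --lora_enable True \
#       --model_max_length 1024 \
#       --output_dir checkpoints/opt-math \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 2e-5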