forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_opinion.py
215 lines (186 loc) · 8.97 KB
/
train_opinion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
import random
import time
import numpy as np
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.transformers import SkepCrfForTokenClassification, SkepModel, SkepTokenizer
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", default='./checkpoint', type=str, help="The output directory where the model checkpoints will be written.")
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. "
"Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--learning_rate", default=5e-7, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--epochs", default=10, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
args = parser.parse_args()
# yapf: enable
def set_seed(seed):
"""Sets random seed."""
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def convert_example_to_feature(example,
tokenizer,
max_seq_len=512,
no_entity_label="O",
is_test=False):
"""
Builds model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. And creates a mask from the two sequences passed
to be used in a sequence-pair classification task.
A skep_ernie_1.0_large_ch/skep_ernie_2.0_large_en sequence has the following format:
::
- single sequence: ``[CLS] X [SEP]``
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
A skep_ernie_1.0_large_ch/skep_ernie_2.0_large_en sequence pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
example(obj:`list[str]`): List of input data, containing text and label if it have label.
tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
which contains most of the methods. Users should refer to the superclass for more information regarding methods.
max_seq_len(obj:`int`): The maximum total input sequence length after tokenization.
Sequences longer than this will be truncated, sequences shorter will be padded.
no_entity_label(obj:`str`, defaults to "O"): The label represents that the token isn't an entity.
is_test(obj:`False`, defaults to `False`): Whether the example contains label or not.
Returns:
input_ids(obj:`list[int]`): The list of token ids.
token_type_ids(obj: `list[int]`): List of sequence pair mask.
label(obj:`list[int]`, optional): The input label if not test data.
"""
tokens = example['tokens']
labels = example['labels']
tokenized_input = tokenizer(
tokens,
return_length=True,
is_split_into_words=True,
max_seq_len=max_seq_len)
input_ids = tokenized_input['input_ids']
token_type_ids = tokenized_input['token_type_ids']
seq_len = tokenized_input['seq_len']
if is_test:
return input_ids, token_type_ids, seq_len
else:
labels = labels[:(max_seq_len - 2)]
encoded_label = np.array(
[no_entity_label] + labels + [no_entity_label], dtype="int64")
return input_ids, token_type_ids, seq_len, encoded_label
def create_dataloader(dataset,
mode='train',
batch_size=1,
batchify_fn=None,
trans_fn=None):
if trans_fn:
dataset = dataset.map(trans_fn)
shuffle = True if mode == 'train' else False
if mode == 'train':
batch_sampler = paddle.io.DistributedBatchSampler(
dataset, batch_size=batch_size, shuffle=shuffle)
else:
batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=batch_size, shuffle=shuffle)
return paddle.io.DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=batchify_fn,
return_list=True)
if __name__ == "__main__":
paddle.set_device(args.device)
rank = paddle.distributed.get_rank()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
train_ds = load_dataset("cote", "dp", splits=['train'])
# The COTE_DP dataset labels with "BIO" schema.
label_map = {label: idx for idx, label in enumerate(train_ds.label_list)}
# `no_entity_label` represents that the token isn't an entity.
no_entity_label_idx = label_map.get("O", 2)
set_seed(args.seed)
skep = SkepModel.from_pretrained('skep_ernie_1.0_large_ch')
model = SkepCrfForTokenClassification(
skep, num_classes=len(train_ds.label_list))
tokenizer = SkepTokenizer.from_pretrained('skep_ernie_1.0_large_ch')
trans_func = partial(
convert_example_to_feature,
tokenizer=tokenizer,
max_seq_len=args.max_seq_length,
no_entity_label=no_entity_label_idx,
is_test=False)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input ids
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # token type ids
Stack(dtype='int64'), # sequence lens
Pad(axis=0, pad_val=no_entity_label_idx) # labels
): [data for data in fn(samples)]
train_data_loader = create_dataloader(
train_ds,
mode='train',
batch_size=args.batch_size,
batchify_fn=batchify_fn,
trans_fn=trans_func)
if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
state_dict = paddle.load(args.init_from_ckpt)
model.set_dict(state_dict)
model = paddle.DataParallel(model)
num_training_steps = len(train_data_loader) * args.epochs
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=args.learning_rate,
parameters=model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
metric = ChunkEvaluator(label_list=train_ds.label_list, suffix=True)
global_step = 0
tic_train = time.time()
for epoch in range(1, args.epochs + 1):
for step, batch in enumerate(train_data_loader, start=1):
input_ids, token_type_ids, seq_lens, labels = batch
loss = model(
input_ids, token_type_ids, seq_lens=seq_lens, labels=labels)
avg_loss = paddle.mean(loss)
global_step += 1
if global_step % 10 == 0 and rank == 0:
print(
"global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
% (global_step, epoch, step, avg_loss,
10 / (time.time() - tic_train)))
tic_train = time.time()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if global_step % 100 == 0 and rank == 0:
save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_name = os.path.join(save_dir, "model_state.pdparam")
# Need better way to get inner model of DataParallel
paddle.save(model._layers.state_dict(), file_name)