-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain_helper.py
147 lines (113 loc) · 5.76 KB
/
train_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# -*- coding: utf-8 -*-
"""
@Time : 2019/1/30 14:01
@Author : MaCan ([email protected])
@File : train_helper.py
"""
import argparse
import os
__all__ = ['get_args_parser']


def _str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    Used as an argparse ``type=`` callback. Plain ``type=bool`` must not be
    used for options that take a value: ``bool('False')`` is ``True`` because
    every non-empty string is truthy.

    :param value: the raw option value (already-bool values pass through).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised
        boolean spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def get_args_parser(args=None):
    """Build the BERT NER command-line interface and parse the arguments.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser itself.

    :param args: list of argument strings to parse. ``None`` (the default)
        parses ``sys.argv[1:]``, exactly like the former zero-argument call.
    :return: ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    # Group 1: file paths (data, BERT files, outputs).
    group1 = parser.add_argument_group(
        'File Paths',
        'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
    group1.add_argument('-data_dir', type=str,
                        default="./zhejiang/data_ner",
                        help='train.csv, dev.csv and test.csv 数据文件存放路径')
    # Relative default, consistent with -init_checkpoint / -vocab_file below
    # (previously a machine-specific absolute Windows path).
    group1.add_argument('-bert_config_file', type=str,
                        default=r"./chinese_L-12_H-768_A-12/bert_config.json")
    group1.add_argument('-output_dir', type=str,
                        default=r"./zhejiang/output_ner",
                        help='模型训练后保存路径')
    group1.add_argument('-init_checkpoint', type=str,
                        default=r"./chinese_L-12_H-768_A-12",  # initial BERT model
                        help='Initial checkpoint (usually from a pre-trained BERT model).')
    group1.add_argument('-vocab_file', type=str,
                        default=r"./chinese_L-12_H-768_A-12/vocab.txt")

    # Group 2: model / training hyper-parameters.
    group2 = parser.add_argument_group('Model Config', 'config the model params')
    group2.add_argument('-max_seq_length', type=int,
                        default=128,
                        help='输入序列的允许最大长度,即句子 tokens 的长度')
    # Stage switches. do_train / do_eval are off unless their flag is passed
    # (store_true, so the flag actually enables them).
    group2.add_argument('-do_train', action='store_true',
                        default=False,
                        help='Whether to run training.')
    group2.add_argument('-do_eval', action='store_true',
                        default=False,
                        help='Whether to run eval on the dev set.')
    # do_predict is on by default; passing -do_predict turns it OFF.
    group2.add_argument('-do_predict', action='store_false',
                        default=True,
                        help='Whether to run the predict in inference mode on the test set.')
    group2.add_argument('-batch_size', type=int,
                        default=32,
                        help='Total batch size for training, eval and predict.')
    group2.add_argument('-learning_rate', type=float,
                        default=1e-5,
                        help='The initial learning rate for Adam.')
    group2.add_argument('-num_train_epochs', type=float,
                        default=10,
                        help='Total number of training epochs to perform.')
    group2.add_argument('-dropout_rate', type=float,
                        default=0.5,
                        help='Dropout rate')
    group2.add_argument('-clip', type=float,
                        default=0.5,
                        help='Gradient clip')
    # argparse %-formats help strings, so a literal percent sign must be '%%'
    # (a bare '%' makes `-h` raise TypeError).
    group2.add_argument('-warmup_proportion', type=float,
                        default=0.1,
                        help='Proportion of training to perform linear learning rate warmup for. '
                             'E.g., 0.1 = 10%% of training.')
    group2.add_argument('-lstm_size', type=int,
                        default=None,
                        help='size of lstm units.')
    group2.add_argument('-num_layers', type=int,
                        default=0,
                        help='number of rnn layers, default is 0.')
    group2.add_argument('-cell', type=str,
                        default='lstm',
                        help='which rnn cell used.')
    group2.add_argument('-save_checkpoints_steps', type=int,
                        default=500,
                        help='save_checkpoints_steps')
    group2.add_argument('-save_summary_steps', type=int,
                        default=10,
                        help='save_summary_steps.')
    # Boolean options that take an explicit value (e.g. `-clean False`) use
    # _str2bool; type=bool would turn any non-empty string into True.
    # filter_adam_var: whether to drop the Adam optimizer variables after
    # training instead of storing them in the model.
    group2.add_argument('-filter_adam_var', type=_str2bool,
                        default=False,
                        help='训练完之后是否删除adam的参数,不存储在model中')
    group2.add_argument('-do_lower_case', type=_str2bool,
                        default=True,
                        help='Whether to lower case the input text.')
    # clean: whether to delete the files under output_dir; keep False to
    # continue training an existing model.
    group2.add_argument('-clean', type=_str2bool,
                        default=False,
                        help="是否清除output路径下面文件, 继续训练模型请设置为False")
    # '-1' selects CPU, '0' selects a GPU (per the original inline notes).
    group2.add_argument('-device_map', type=str,
                        default='-1',
                        help='which device to use for training')
    # Labels may be given as a file (one label per line) or a comma-separated string.
    group2.add_argument('-label_list', type=str,
                        default=None,
                        help='User define labels, can be a file with one label one line or a string using \',\' split')

    parser.add_argument('-verbose', action='store_true',
                        default=False,
                        help='turn on tensorflow logging for debug')
    parser.add_argument('-ner', type=str,
                        default='ner',
                        help='which model to train')
    return parser.parse_args(args)