-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptions.py
107 lines (91 loc) · 5.24 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import os
from pathlib import Path
import logging
# Module-level logger; Options.print_options() emits its formatted option table through it.
logger = logging.getLogger(__name__)
class Options:
    """Command-line option builder for training / evaluation scripts.

    The constructor registers the always-present basic, dataset and training
    options (``initialize_parser``). Callers opt into further option groups
    with the ``add_*_options`` methods and then call ``parse``.
    """

    def __init__(self):
        # ArgumentDefaultsHelpFormatter appends each option's default value to its help text.
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialize_parser()

    def add_optim_options(self):
        """Register optimizer, learning-rate and scheduler options."""
        self.parser.add_argument('--warmup_steps', type=int, default=1000)
        self.parser.add_argument('--total_steps', type=int, default=1000)
        self.parser.add_argument('--scheduler_steps', type=int, default=None,
                                 help='total number of steps for the scheduler, if None then scheduler_total_step = total_step')
        self.parser.add_argument('--accumulation_steps', type=int, default=1)
        self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
        self.parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
        self.parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
        self.parser.add_argument('--optim', type=str, default='adam')
        self.parser.add_argument('--scheduler', type=str, default='fixed')
        self.parser.add_argument('--weight_decay', type=float, default=0.1)
        self.parser.add_argument('--fixed_lr', action='store_true')

    def add_eval_options(self):
        """Register flags controlling what evaluation writes out."""
        self.parser.add_argument('--write_results', action='store_true', help='save results')
        self.parser.add_argument('--write_crossattention_scores', action='store_true',
                                 help='save dataset with cross-attention scores')

    def add_reader_options(self):
        """Register reader-model options: data paths, model size and text lengths."""
        self.parser.add_argument('--train_data', type=str, default='none', help='path of train data')
        self.parser.add_argument('--eval_data', type=str, default='none', help='path of eval data')
        self.parser.add_argument('--model_size', type=str, default='base')
        self.parser.add_argument('--use_checkpoint', action='store_true', help='use checkpoint in the encoder')
        self.parser.add_argument('--text_maxlength', type=int, default=200,
                                 help='maximum number of tokens in text segments (question+passage)')
        self.parser.add_argument('--answer_maxlength', type=int, default=-1,
                                 help='maximum number of tokens used to train the model, no truncation if -1')
        self.parser.add_argument('--no_title', action='store_true',
                                 help='article titles not included in passages')
        # NOTE(review): meaning of n_descriptions is not documented here; confirm with its consumers.
        self.parser.add_argument('--n_descriptions', type=int, default=17)

    def initialize_parser(self):
        """Register the options that every script gets (called from ``__init__``)."""
        # basic parameters
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
        self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here')
        self.parser.add_argument('--model_path', type=str, default='none', help='path for retraining')
        # dataset parameters
        self.parser.add_argument("--batch_size_per_gpu", default=1, type=int,
                                 help="Batch size per GPU/CPU for training.")
        self.parser.add_argument('--maxload', type=int, default=-1)
        self.parser.add_argument("--local_rank", type=int, default=-1,
                                 help="For distributed training: local_rank")
        self.parser.add_argument("--main_port", type=int, default=-1,
                                 help="Main port (for multi-node SLURM jobs)")
        self.parser.add_argument('--seed', type=int, default=0, help="random seed for initialization")
        # training parameters
        self.parser.add_argument('--eval_freq', type=int, default=2000,
                                 help='evaluate model every <eval_freq> steps during training')
        self.parser.add_argument('--save_freq', type=int, default=5000,
                                 help='save model every <save_freq> steps during training')
        self.parser.add_argument('--eval_print_freq', type=int, default=1000,
                                 help='print intermediate results of evaluation every <eval_print_freq> steps')

    def print_options(self, opt):
        """Log all parsed options and save them to ``<checkpoint_dir>/<name>/opt.log``.

        Each option is rendered as a right-aligned name and a left-aligned value.
        Options whose value differs from the parser default are annotated with
        ``(default: ...)``. As a side effect the directory
        ``<checkpoint_dir>/<name>/models`` is created (parents included) if missing.

        Args:
            opt: namespace returned by ``parse``.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default_value = self.parser.get_default(k)
            comment = f'\t(default: {default_value})' if v != default_value else ''
            lines.append(f'{str(k):>30}: {str(v):<40}{comment}\n')
        message = '\n' + ''.join(lines)

        expr_dir = Path(opt.checkpoint_dir) / opt.name
        model_dir = expr_dir / 'models'
        model_dir.mkdir(parents=True, exist_ok=True)
        # Explicit encoding so the log file does not depend on the platform locale.
        with open(expr_dir / 'opt.log', 'wt', encoding='utf-8') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        logger.info(message)

    def parse(self, args=None):
        """Parse arguments into an ``argparse.Namespace``.

        Args:
            args: optional list of argument strings. When None (the default,
                matching the previous behaviour) argparse reads ``sys.argv[1:]``.
        """
        return self.parser.parse_args(args)
def get_options(use_reader=False,
                use_retriever=False,
                use_optim=False,
                use_eval=False):
    """Build an ``Options`` parser with the requested option groups and parse ``sys.argv``.

    Args:
        use_reader: register reader options (``Options.add_reader_options``).
        use_retriever: register retriever options. ``Options`` in this file
            defines no ``add_retriever_options`` method, so this currently raises.
        use_optim: register optimizer / scheduler options.
        use_eval: register evaluation output options.

    Returns:
        The ``argparse.Namespace`` produced by ``Options.parse()``.

    Raises:
        NotImplementedError: if ``use_retriever`` is True and ``Options`` has
            no ``add_retriever_options`` method.
    """
    options = Options()
    if use_reader:
        options.add_reader_options()
    if use_retriever:
        # Calling the missing method directly used to fail with a bare
        # AttributeError; report the unsupported flag explicitly instead.
        add_retriever_options = getattr(options, 'add_retriever_options', None)
        if add_retriever_options is None:
            raise NotImplementedError(
                'use_retriever=True is not supported: Options defines no '
                'add_retriever_options() method')
        add_retriever_options()
    if use_optim:
        options.add_optim_options()
    if use_eval:
        options.add_eval_options()
    return options.parse()