-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_multi_scale_multi_loss.py
53 lines (43 loc) · 1.93 KB
/
predict_multi_scale_multi_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch, configargparse
from data import load_asap_data
from model_architechure_bert_multi_scale_multi_loss import DocumentBertScoringModel
def _initialize_arguments(p: configargparse.ArgParser):
    """Register the prediction options on ``p``, parse them and derive run settings.

    After parsing, several attributes of the resulting namespace are
    overwritten:

    * ``test_file`` becomes ``<data_dir>/p8_fold3_test.txt``.
    * ``model_directory`` becomes ``<model_directory>/<prompt>_<fold>``.
    * ``bert_model_path`` is set to that derived ``model_directory``.
    * ``device`` is ``'cuda'`` when ``--cuda`` was given and
      ``torch.cuda.is_available()`` is true, otherwise ``'cpu'``.

    Args:
        p: Parser on which the options are registered and which is then asked
            to parse the arguments.

    Returns:
        The parsed namespace with the derived attributes described above.
    """
    p.add('--bert_model_path', help='bert_model_path')
    p.add('--efl_encode', action='store_true', help='is continue training')
    p.add('--r_dropout', help='r_dropout', type=float)
    p.add('--batch_size', help='batch_size', type=int)
    p.add('--bert_batch_size', help='bert_batch_size', type=int)
    p.add('--cuda', action='store_true', help='use gpu or not')
    p.add('--device')
    p.add('--model_directory', help='model_directory')
    p.add('--test_file', help='test data file')
    p.add('--data_dir', help='data directory to store asap experiment data')
    p.add('--data_sample_rate', help='data_sample_rate', type=float)
    p.add('--prompt', help='prompt')
    p.add('--fold', help='fold')
    p.add('--chunk_sizes', help='chunk_sizes', type=str)
    p.add('--result_file', help='pred result file path', type=str)
    args = p.parse_args()

    # NOTE(review): the test file name is hard-coded, so any --test_file value
    # (and the --prompt/--fold pair) is ignored here -- confirm this is intended.
    args.test_file = "%s/p8_fold3_test.txt" % args.data_dir
    args.model_directory = "%s/%s_%s" % (args.model_directory, args.prompt, args.fold)
    # The BERT weights are loaded from the same per-prompt/fold directory.
    args.bert_model_path = args.model_directory

    # Always resolve ``device`` explicitly.  The CPU branch used to assign to a
    # misspelled attribute (``args.dev``), which left ``args.device`` at its
    # parsed value (None by default) whenever CUDA was not used.
    use_cuda = torch.cuda.is_available() and args.cuda
    args.device = 'cuda' if use_cuda else 'cpu'
    return args
if __name__ == "__main__":
    # Build the parser (asap.ini is registered as the default config file)
    # and resolve the run settings.
    parser = configargparse.ArgParser(default_config_files=["asap.ini"])
    args = _initialize_arguments(parser)
    print(args)

    # Read the test records in a single pass, keeping only the text and the
    # label of each (first field is ignored).
    pairs = [(text, label) for _, text, label in load_asap_data(args.test_file)]
    test_documents = [text for text, _ in pairs]
    test_labels = [label for _, label in pairs]
    print("sample number:", len(test_documents))
    print("label number:", len(test_labels))

    # Instantiate the scoring model from the parsed options and hand it the
    # (documents, labels) pair for prediction.
    scorer = DocumentBertScoringModel(args=args)
    scorer.predict_for_regress((test_documents, test_labels))