# train.py: training entry point for the DefinitionProbing definition model.

import hashlib
import json
import os

import torch
import torch.nn as nn
from dotmap import DotMap

from config import StrictConfigParser
from data import get_dm_conf, DataMaker
from embeddings import Word2Vec
from model import DefinitionProbing
from modules import get_pretrained_transformer
from trainer import build_trainer

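# Parse the run configuration; defaults come from config/config.yaml and can
# be overridden on the command line.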
config_parser = StrictConfigParser(default=os.path.join("config", "config.yaml"))


if __name__ == "__main__":
    config = config_parser.parse_args()
    use_cuda = config.device == "cuda" and torch.cuda.is_available()
    torch.manual_seed(config.seed)
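
    # Data fields: examples and target words are configured with the pretrained
    # encoder's tokenization, while definitions use standard ("normal") tokens.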
    example_field = get_dm_conf(config.encoder, "example")
    word_field = get_dm_conf(config.encoder, "word")
    definition_field = get_dm_conf("normal", "definition")
    data_fields = [example_field, word_field, definition_field]
    if config.variational or config.defbert:
        definition_ae_field = get_dm_conf(config.encoder, "definition_ae")
        data_fields.append(definition_ae_field)
    device = torch.device("cuda" if use_cuda else "cpu")
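
    # Derive a unique serialization directory from the dataset name and a short
    # SHA-224 hash of the full config, so different runs do not collide.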
    config.update(
        {
            "serialization_dir": config.serialization_dir
            + config.dataset
            + "/"
            + hashlib.sha224(
                json.dumps(dict(config.to_dict()), sort_keys=True).encode()
            ).hexdigest()[:6]
        }
    )
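    # The directory is written to below (config.json, model_architecture), so
    # create it up front. Assumption: build_trainer may also create it;
    # exist_ok=True makes this call safe either way.
    os.makedirs(config.serialization_dir, exist_ok=True)
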
    ############### DATA ###############
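    # Build vocabularies and dataset splits; the "example" and "word" fields
    # share a single vocabulary.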
    datamaker = DataMaker(data_fields, config.datapath)
    datamaker.build_data(
        config.dataset,
        max_len=config.max_length,
        lowercase=config.lowercase,
        shared_vocab_fields=["example", "word"],
    )
    ####################################

    ####################################
    ############### MODEL ##############
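    # Target-side (definition decoder) embeddings, initialized from Word2Vec
    # vectors over the definition vocabulary and fine-tuned (freeze=False).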
    embeddings = DotMap(
        {
            "tgt": nn.Embedding.from_pretrained(
                Word2Vec(datamaker.vocab.definition.itos),
                freeze=False,
                padding_idx=datamaker.vocab.definition.stoi["<pad>"],
            )
        }
    )
    embeddings.tgt.unk_idx, embeddings.tgt.padding_idx = (
        datamaker.vocab.definition.stoi["<unk>"],
        datamaker.vocab.definition.stoi["<pad>"],
    )
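
    # Dropout rates grouped by side (src/tgt) and position (input/output),
    # plus word-level dropout probabilities.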
    dropout = DotMap(
        {
            "src": {
                "input": config.src_input_dropout,
                "output": config.src_output_dropout,
            },
            "tgt": {
                "input": config.tgt_input_dropout,
                "output": config.tgt_output_dropout,
            },
            "tgt_word_dropout": config.tgt_word_dropout,
            "src_word_dropout": config.src_word_dropout,
        }
    )
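
    # Pretrained transformer that encodes the example sentences.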
    encoder = get_pretrained_transformer(config.encoder)
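    # Second encoder for the definition autoencoder branch; with config.tied
    # it shares weights with the example encoder.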
    if config.variational or config.defbert:
        if config.tied:
            definition_encoder = encoder
        else:
            definition_encoder = get_pretrained_transformer(config.encoder)
    else:
        definition_encoder = None
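
    # Assemble the full model and move it to the resolved device.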
    model = DefinitionProbing(
        encoder=encoder,
        encoder_pretrained=True,
        encoder_frozen=config.encoder_frozen,
        decoder_hidden=config.decoder_hidden,
        embeddings=embeddings,
        max_layer=config.max_layer,
        src_pad_idx=datamaker.vocab.example.pad_token_id,
        teacher_forcing_p=config.teacher_forcing_p,
        attentional=config.attentional,
        aggregator=config.aggregator,
        variational=config.variational,
        latent_size=config.latent_size,
        word_dropout_p=config.tgt_word_dropout,
        definition_encoder=definition_encoder,
        decoder_num_layers=config.decoder_num_layers,
    ).to(device)  # use the availability-checked device, not config.device directly
    ####################################

    ####################################
    ########## TRAINING LOOP ###########
    trainer = build_trainer(model, config, datamaker)
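
    # Snapshot the resolved config and the model architecture (with trainable
    # parameter count) into the serialization directory for reproducibility.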
    with open(config.serialization_dir + "/config.json", "w") as f:
        json.dump(dict(config.to_dict()), f)
    with open(config.serialization_dir + "/model_architecture", "w") as f:
        f.write(
            repr(model)
            + "\nParameter Count:"
            f" {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
        )
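
    # Epoch loop: _train returning None signals that training should stop;
    # Ctrl-C is caught so a run can be interrupted cleanly.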
    try:
        for i in range(config.max_epochs):
            train_out = trainer._train(config.train_batch_size)
            if train_out is None:
                break
            valid_out = trainer._validate(config.valid_batch_size)
            test_out = trainer._test(config.valid_batch_size)
    except KeyboardInterrupt:
        print("Stopping training, train counter =", trainer._train_counter)