forked from polyusmart/Personalized-Hashtag-Preferences
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
147 lines (119 loc) · 7.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os

import torch
from tqdm import tqdm

from utils.loss import weighted_class_bceloss
from utils.loss import vae_loss_function
from utils.loss import l1_penalty
def fix_model(model):
    """Freeze ``model`` so that none of its parameters receive gradients.

    Sets ``requires_grad = False`` on every tensor yielded by
    ``model.parameters()``. Returns nothing.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = False
def unfix_model(model):
    """Unfreeze ``model`` so that all of its parameters receive gradients again.

    Sets ``requires_grad = True`` on every tensor yielded by
    ``model.parameters()``. Returns nothing.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = True
def _move_batch(batch, use_cuda):
    """Return the tensors of ``batch`` as a tuple, moved to the GPU when ``use_cuda``."""
    if use_cuda:
        return tuple(tensor.cuda() for tensor in batch)
    return tuple(batch)


def _batch_loss(mlp_model, joint_train_flag, inputs, labels, weights):
    """Run one forward pass and return ``(loss, pred_labels)`` for a batch.

    The VAE sub-model is frozen unless ``joint_train_flag`` is true. The loss is
    ``100 * classifier_loss``. In joint training, ``vae_loss / 100`` is added,
    where ``vae_loss`` is the VAE loss plus ``l1_strength`` times the L1 penalty
    on the VAE's ``fcd1.weight``.
    """
    vae_model = mlp_model.get_vae_model()
    if joint_train_flag:
        unfix_model(vae_model)
    else:
        fix_model(vae_model)
    # The model returns 7 values; the first two (weight, topic_words) are unused here.
    _, _, pred_labels, recon_batch, data_bow, mu, logvar = mlp_model(joint_train_flag, *inputs)
    mlp_loss = weighted_class_bceloss(pred_labels, labels.reshape(-1, 1), weights)
    if not joint_train_flag:
        return 100 * mlp_loss, pred_labels
    vae_loss = vae_loss_function(recon_batch, data_bow, mu, logvar)
    vae_loss = vae_loss + vae_model.l1_strength * l1_penalty(vae_model.fcd1.weight)
    # Fixed scale factors balance the classifier and VAE objectives (kept as-is).
    return 100 * mlp_loss + vae_loss / 100, pred_labels


def _count_correct(pred_labels, labels, threshold):
    """Return ``(positives, correct_positives, negatives, correct_negatives)``.

    A sample is positive when its label equals 1. A prediction above
    ``threshold`` counts as a positive prediction and anything else as a
    negative one, so every sample is scored exactly once. The counts are
    computed with tensor ops, so only a few device syncs happen per batch.
    """
    is_positive = labels.reshape(-1) == 1
    predicted_positive = pred_labels.detach().reshape(-1) > threshold
    positives = int(is_positive.sum())
    correct_positives = int((is_positive & predicted_positive).sum())
    correct_negatives = int((~is_positive & ~predicted_positive).sum())
    return positives, correct_positives, is_positive.numel() - positives, correct_negatives


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is zero."""
    return numerator / denominator if denominator else 0.0


def _run_phase(mlp_model, dataloader, weights, joint_train_flag, use_cuda, threshold, optimizer=None):
    """Make one pass over ``dataloader``, updating the model only if ``optimizer`` is given.

    Returns ``(positive_acc, negative_acc, total_loss)``, where ``total_loss`` is
    the sum over batches of ``batch_loss * batch_size``.
    """
    training = optimizer is not None
    mlp_model.train(training)  # train(False) is equivalent to eval()
    positives = correct_positives = negatives = correct_negatives = 0
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for batch in tqdm(dataloader):
            # Batches are (user_features, user_lens, hashtag_features,
            # hashtag_lens, bow_hashtag_features, labels).
            *inputs, labels = _move_batch(batch, use_cuda)
            if training:
                optimizer.zero_grad()
            loss, pred_labels = _batch_loss(mlp_model, joint_train_flag, inputs, labels, weights)
            total_loss += loss.item() * len(labels)
            pos, pos_ok, neg, neg_ok = _count_correct(pred_labels, labels, threshold)
            positives += pos
            correct_positives += pos_ok
            negatives += neg
            correct_negatives += neg_ok
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mlp_model.parameters(), max_norm=20.0, norm_type=2)
                optimizer.step()
    return _ratio(correct_positives, positives), _ratio(correct_negatives, negatives), total_loss


def train_mlp(mlp_model, train_dataset, valid_dataset, train_dataloader, valid_dataloader, optimizer, weights, scheduler, joint_train_flag, epoch_num,
              select_after_epoch=50, threshold=0.95, save_dir='./ModelRecords'):
    """Train ``mlp_model`` for ``epoch_num`` epochs and return the selected epoch index.

    Each epoch trains on ``train_dataloader``, evaluates on ``valid_dataloader``
    and prints both phases' metrics. It then calls ``scheduler.step`` with the
    mean validation loss (presumably a ReduceLROnPlateau-style scheduler;
    confirm at the call site). Every epoch's weights are saved as
    ``model_{epoch}.pt``. Whenever the validation loss improves, they are also
    saved as ``best_model_{epoch}.pt``.

    Args:
        mlp_model: model exposing ``get_vae_model()``, called as
            ``mlp_model(joint_train_flag, *features)``.
        train_dataset, valid_dataset: used only for their lengths, to average losses.
        weights: class weights forwarded to ``weighted_class_bceloss``.
        joint_train_flag: when false the VAE sub-model is frozen and only the
            classifier loss is optimised.
        epoch_num: number of epochs to run.
        select_after_epoch: epochs with a 0-based index below this are
            ignored when choosing the returned epoch.
        threshold: prediction cut-off used for the accuracy metrics.
        save_dir: directory for the checkpoints; created if missing.

    Returns:
        The 0-based index of the epoch at or after ``select_after_epoch`` with
        the highest validation ``positive_acc + negative_acc``. Ties go to the
        earliest such epoch. Returns 0 if training did not reach
        ``select_after_epoch``.
    """
    use_cuda = torch.cuda.is_available()
    os.makedirs(save_dir, exist_ok=True)
    best_valid_loss = float('inf')
    valid_scores = []
    for epoch in range(epoch_num):
        pos_acc, neg_acc, total_loss = _run_phase(
            mlp_model, train_dataloader, weights, joint_train_flag, use_cuda, threshold, optimizer=optimizer)
        print('train positive_acc: %f train negative_acc: %f train_loss: %f' %
              (pos_acc, neg_acc, total_loss / len(train_dataset)))

        pos_acc, neg_acc, total_loss = _run_phase(
            mlp_model, valid_dataloader, weights, joint_train_flag, use_cuda, threshold)
        valid_loss = total_loss / len(valid_dataset)
        print('epoch: %d valid positive_acc: %f valid negative_acc: %f valid_loss: %f' %
              (epoch + 1, pos_acc, neg_acc, valid_loss))
        scheduler.step(valid_loss)
        print('learning rate: %f' % optimizer.param_groups[0]['lr'])

        valid_scores.append(pos_acc + neg_acc)
        if total_loss < best_valid_loss:
            best_valid_loss = total_loss
            print('Current best!')
            torch.save(mlp_model.state_dict(), os.path.join(save_dir, f'best_model_{epoch}.pt'))
        torch.save(mlp_model.state_dict(), os.path.join(save_dir, f'model_{epoch}.pt'))

    # Choose among late epochs only. Offset the slice index back to the full list,
    # so an equal score in an early epoch cannot be picked by mistake.
    late_scores = valid_scores[select_after_epoch:]
    if not late_scores:
        return 0
    return select_after_epoch + late_scores.index(max(late_scores))