train.py
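"""Train a model defined in hparams.yaml and report held-out performance.

Usage (the model name must match an entry in hparams.yaml; "cnn" is only a
hypothetical example):

    python train.py cnn
"""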
import json
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from models.model_manager import getModel
from utils.params import Params
from utils.preprocess import prepare_data


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_name",
        type=str,
        help="Name of the model as defined in hparams.yaml."
    )
    args = parser.parse_args()

    # Use CUDA if available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load hyperparameters for the requested model
    params = Params("hparams.yaml", args.model_name)

    # Load the dataset and hold out 30% of it for validation
    inputs, targets = prepare_data(params)
    inputs_train, inputs_test, targets_train, targets_test = train_test_split(
        inputs, targets, test_size=0.3
    )

    # Move the full dataset onto the device up front (assumes it fits in memory)
    inputs_train = torch.Tensor(inputs_train).to(device)
    targets_train = torch.LongTensor(targets_train).to(device)
    inputs_test = torch.Tensor(inputs_test).to(device)
    targets_test = torch.LongTensor(targets_test).to(device)

    # Build the model (10 output classes), optimizer, and loss
    model = getModel(params, inputs_train.shape, 10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=params.lr)
    criterion = nn.CrossEntropyLoss()

    # Training parameters
    num_epochs = params.num_epochs
    batch_size = params.batch_size
    patience = params.patience    # epochs to wait for improvement before stopping
    best_val_loss = float('inf')  # best validation loss seen so far
    patience_counter = 0          # epochs without improvement
    log = {}

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for i in tqdm(range(0, len(inputs_train), batch_size)):
            # Get batch of inputs and targets
            inputs_batch = inputs_train[i:i + batch_size]
            targets_batch = targets_train[i:i + batch_size]

            # Forward pass
            outputs = model(inputs_batch)
            loss = criterion(outputs, targets_batch)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            running_loss += loss.item()
            num_batches += 1

        # Average training loss over all batches, including a partial final one
        train_loss = running_loss / num_batches
        # Evaluate on the held-out set
        model.eval()
        with torch.no_grad():
            outputs = model(inputs_test)
            val_loss = criterion(outputs, targets_test).item()
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == targets_test).sum().item()
            accuracy = correct / len(targets_test)

        # Record losses for this epoch
        log[epoch] = {
            'train_loss': train_loss,
            'val_loss': val_loss,
        }

        print(f"EPOCH {epoch + 1}, train_loss: {train_loss}, val_loss: {val_loss}, val_acc: {accuracy}")

        # Early stopping: checkpoint on improvement, stop after `patience`
        # epochs without one
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
            }, params.checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter > patience:
                print("Early stopping")
                break

    print('Finished Training')

    # Final evaluation on the held-out set
    print("Performance on test set:")
    model.eval()
    with torch.no_grad():
        predictions = torch.argmax(model(inputs_test), dim=1)
    print(classification_report(
        targets_test.cpu().numpy(),
        predictions.cpu().numpy()
    ))

    # Save the per-epoch log as JSON (integer epoch keys become strings)
    log_path = params.log_path
    with open(log_path, 'w') as f:
        json.dump(log, f)
    print(f"Log saved at {log_path}")


if __name__ == "__main__":
    main()
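

# A minimal sketch of restoring the best checkpoint for inference, assuming the
# same hparams.yaml entry and input shape used during training ("cnn" and
# `inputs` below are hypothetical placeholders, not part of this script):
#
#     params = Params("hparams.yaml", "cnn")
#     model = getModel(params, inputs.shape, 10)
#     checkpoint = torch.load(params.checkpoint_path)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     model.eval()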