-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_xlnet.py
32 lines (28 loc) · 1.07 KB
/
train_xlnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from models import xlnet
import torch
import os
from utils import config
from dataset import get_train_test_loaders
if __name__ == "__main__":
    # Script entry point: build the XLNet data loaders, train the model with
    # the configured hyper-parameters, and write the trained weights to disk.

    # Turn off PyTorch's "always warn" mode.
    torch.set_warn_always(False)
    # NOTE(review): presumably meant to quiet TensorFlow's C++ logging. It is
    # set after the imports at the top of the file have already run, so
    # confirm it still takes effect where it is needed.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    # NOTE(review): the meaning of num_files=2 is defined by
    # dataset.get_train_test_loaders -- confirm against that helper.
    train_loader, test_loader, num_labels, weights = get_train_test_loaders(
        "xlnet", num_files=2, get_weights=True
    )

    # Hyper-parameters are read from the project configuration.
    EPOCHS = config("train.epochs")
    lr = config("train.lr")

    print("Training model...")
    # exist_ok=True makes directory creation idempotent and avoids the
    # check-then-create race of a separate os.path.exists() test.
    # NOTE(review): presumably xlnet.train_model writes into "checkpoints";
    # confirm against models.xlnet.
    os.makedirs("checkpoints", exist_ok=True)

    # NOTE(review): freeze_num=3 is hard-coded here; its exact semantics are
    # defined by xlnet.train_model -- confirm before changing.
    model = xlnet.train_model(
        train_loader,
        test_loader,
        num_labels,
        epochs=EPOCHS,
        freeze_num=3,
        learning_rate=lr,
        weights=weights,
    )
    print("Model trained.")

    # The output file name embeds the configured number of epochs.
    PATH = f"trained_models/XLNet-cased-{EPOCHS}"
    os.makedirs("trained_models", exist_ok=True)
    # Only the model's state dict is serialized, not the model object itself.
    torch.save(model.state_dict(), PATH)
    print(f"Model saved to {PATH}")