forked from guanwaiguren/Pytorch-Human-Pose-Estimation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
34 lines (24 loc) · 845 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import builder
import trainer
import os
import time
import argparse
from opts import opts
# Parse command-line options. The parsed result intentionally rebinds the
# name `opts` (shadowing the imported `opts` class), because everything
# below reads its configuration through this name.
opts = opts().parse()

# Select the global floating-point precision before any model or tensor is
# created. `torch.set_default_dtype` replaces the deprecated
# `torch.set_default_tensor_type('torch.DoubleTensor' / 'torch.FloatTensor')`;
# for these CPU tensor types the effect is the same.
torch.set_default_dtype(torch.float64 if opts.usedouble else torch.float32)

# The project Builder constructs each training component from the options.
Builder = builder.Builder(opts)

# Move the model to its target device *before* building the optimizer, as the
# torch.optim documentation recommends: optimizer state created or loaded
# while the parameters are still on CPU would otherwise stay on the wrong
# device.
Model = Builder.Model()
Model = Model.to(opts.gpuid)
Optimizer = Builder.Optimizer(Model)

Loss = Builder.Loss()
Metrics = Builder.Metric()
TrainDataLoader, ValDataLoader = Builder.DataLoaders()
# NOTE(review): presumably the epoch to start (or resume) from — confirm in builder.
Epoch = Builder.Epoch()

# Log file lives directly under the configured save directory.
File = os.path.join(opts.saveDir, 'log.txt')
Trainer = trainer.Trainer(Model, Optimizer, Loss, Metrics, File, None, opts)

# In test mode only evaluate on the validation loader; otherwise train for
# opts.nEpochs epochs starting from `Epoch`. Using if/else instead of the
# site-provided `exit()` keeps the script working under `python -S` and in
# frozen builds, with the same exit status.
if opts.test:
    Trainer.test(ValDataLoader)
else:
    Trainer.train(TrainDataLoader, ValDataLoader, Epoch, opts.nEpochs)