-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
92 lines (80 loc) · 3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import copy
import os

import numpy as np
import skimage.io as io
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm

from data_preprocess import *
from model_unet import Unet
# Where the best model checkpoint (state_dict) is written during training.
PATH = './model/unet_model.pt'
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input-image transform: convert to tensor, then normalize a single channel
# with mean=0.5, std=0.5 (maps pixel values from [0, 1] to [-1, 1]).
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Masks only need conversion to a tensor (kept in [0, 1], no normalization).
y_transforms = transforms.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=10):
    """Train `model` for `num_epochs` epochs and track the best epoch.

    Parameters
    ----------
    model : nn.Module
        The network to train (already moved to `device`).
    criterion : callable
        Loss function taking (outputs, labels).
    optimizer : torch.optim.Optimizer
        Optimizer over `model.parameters()`.
    dataload : DataLoader
        Yields (input, target) batches.
    num_epochs : int
        Number of passes over the dataset.

    Returns
    -------
    nn.Module
        `model` with the weights of the lowest-average-loss epoch loaded.
        The same weights are also checkpointed to `PATH` whenever a new
        best epoch is found.
    """
    # Snapshot the weights, not the module: `best_model = model` would only
    # alias the object, so the "best" weights would be silently overwritten
    # by later epochs.
    best_state = copy.deepcopy(model.state_dict())
    min_loss = float('inf')  # any real epoch loss beats this
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        epoch_loss = 0.0
        step = 0
        for x, y in tqdm(dataload):
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Guard against an empty DataLoader (step == 0 would divide by zero);
        # an empty epoch can never become the new best.
        avg_loss = epoch_loss / step if step else float('inf')
        print("epoch %d loss:%0.3f" % (epoch, avg_loss))
        if avg_loss < min_loss:
            min_loss = avg_loss
            best_state = copy.deepcopy(model.state_dict())
            # Ensure the checkpoint directory exists before saving.
            os.makedirs(os.path.dirname(PATH), exist_ok=True)
            torch.save(best_state, PATH)
    # Return the model carrying the best weights seen, not the final epoch's.
    model.load_state_dict(best_state)
    return model
# Train the model
def train(data_dir=r"D:\2PM_2023\segmentation\data\00000008自动核质比训练集new-all\TrainDataset\StratumSpinosum\train",
          batch_size=1, num_epochs=10):
    """Build the U-Net, dataset and optimizer, then run training.

    Parameters
    ----------
    data_dir : str
        Root directory of the training images/masks. Defaults to the
        original hard-coded path so existing callers are unaffected.
    batch_size : int
        DataLoader batch size (default 1, as before).
    num_epochs : int
        Number of training epochs forwarded to `train_model` (default 10,
        matching `train_model`'s own default).
    """
    # Single-channel input, single-channel mask output.
    model = Unet(1, 1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    train_dataset = TrainDataset(data_dir,
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(train_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=1)
    train_model(model, criterion, optimizer, dataloaders, num_epochs=num_epochs)
# Save the model's output results (inference helper, currently disabled)
# def test():
# model = Unet(1, 2)
# model.load_state_dict(torch.load(PATH))
# test_dataset = TestDataset("dataset/test", transform=x_transforms,target_transform=y_transforms)
# dataloaders = DataLoader(test_dataset, batch_size=1)
# model.eval()
# import matplotlib.pyplot as plt
# plt.ion()
# with torch.no_grad():
# for index, x in enumerate(dataloaders):
# y = model(x)
# img_y = torch.squeeze(y).numpy()
# img_y = img_y[:, :, np.newaxis]
# img = labelVisualize(2, COLOR_DICT, img_y) if False else img_y[:, :, 0]
# io.imsave("./dataset/test/" + str(index) + "_predict.tif", img)
# plt.pause(0.01)
# plt.show()
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    print("开始训练")  # "Starting training"
    train()
    # print("Training finished, saving model")
    # print("-"*20)
    # print("Starting prediction")
    # test()