-
Notifications
You must be signed in to change notification settings - Fork 0
/
idcardClsMain.py
111 lines (95 loc) · 3.92 KB
/
idcardClsMain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from src.tool import train,test
from src.model import modle
from src.tool.data import *
from torch.hub import load_state_dict_from_url as load_url
import os
from src.tool.test import *
import pickle
# Parameter configuration
# Checkpoint loaded into the model before use; switch to the commented-out
# None below to skip loading.
# load_weight = None
load_weight = "weight/restnet18_cls2_test_0.625_0.7296200394630432.pt"
# Root directory of the image dataset.
directory = 'data/imagenet'
# Class name -> integer label index.
cls_index_dic = {name: index for index, name in enumerate(("ants", "bees"))}
# Fraction of the data assigned to each split.
ratio = dict(train=0.8, val=0.1, test=0.1)
# DataLoader worker processes per split.
num_workers = dict(train=8, val=1, test=0)
# Point torch at the local cache directory for pretrained weights.
os.environ['TORCH_HOME'] = '/home/qiangyu/cls/pretrained'
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps; every split's pipeline ends with these."""
    channel_mean = [0.4802, 0.4481, 0.3975]
    channel_std = [0.2302, 0.2265, 0.2262]
    return [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]


# Random augmentations applied to the training split only, in this order.
_train_augmentations = [
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
]

# Per-split preprocessing pipelines: 'train' augments before converting and
# normalizing; 'val' and 'test' only convert and normalize.
data_transforms = {
    'train': transforms.Compose(_train_augmentations + _to_normalized_tensor()),
    'val': transforms.Compose(_to_normalized_tensor()),
    'test': transforms.Compose(_to_normalized_tensor()),
}

# Dataset
# Path of a pickled dataset split; None means no pickle is configured.
data_pkl = None
# data_pkl = directory + ".pkl"
# Hardware selection: use the first CUDA device when one is available, else the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fine-tuning setup: build the classifier from the project's model module.
# NOTE(review): the positional True presumably requests pretrained weights —
# confirm against modle.restnet18_cls2.
model = modle.restnet18_cls2(True)
if load_weight:
    # map_location remaps the checkpoint's tensors onto the selected device.
    # Without it torch.load restores tensors to the device they were saved from,
    # which fails when that device (e.g. CUDA) does not exist on this machine.
    state_dict = torch.load(load_weight, map_location=device)
    model.load_state_dict(state_dict)
model = model.to(device)

# Loss
criterion = nn.CrossEntropyLoss()
# # train
# # Data preprocessing
# if not data_pkl:
# image_datasets = create_dataset(directory, cls_index_dic, ratio, data_transforms)
# data_loaders = {x: data.DataLoader(image_datasets[x], batch_size=8, shuffle=True, num_workers=num_workers[x])
# for x in ['train', 'val', 'test']}
# dataset_sizes = {x: len(image_datasets[x])
# for x in ['train', 'val', 'test']}
# with open(os.path.join("data","imagenet.pkl"), "wb") as file: # persist the dataset info so it can be reused for testing later
# # Use pickle's dump() function to write the variables to the file
# pickle.dump([image_datasets,data_loaders,dataset_sizes], file)
# else:
# # # Load the saved dataset split
# with open("data/imagenet.pkl", "rb") as file:
# # Use pickle's load() function to load the file contents
# [image_datasets,data_loaders,dataset_sizes] = pickle.load(file)
# optimazer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # optimizer
# num_epochs = 5
# train.train_model(model_name=type(model).__name__, model=model, data_loaders=data_loaders, dataset_sizes=dataset_sizes,
# optimizer=optimazer, criterion=criterion, device=device, num_epochs=num_epochs)
# # test
# best_model_wts = ""
# model.load_state_dict(best_model_wts)
# # Load the saved dataset split
# with open("data/imagenet.pkl", "rb") as file:
# # Use pickle's load() function to load the file contents
# [image_datasets,data_loaders,dataset_sizes] = pickle.load(file)
# test_model(type(model).__name__, model, data_loaders, dataset_sizes, criterion, device, optimazer, phases=['test'])
# Convert the PyTorch model to ONNX format
# model = torch.load('best.pt')
# model.eval()
# input_names = ['input']
# output_names = ['output']
# x = torch.randn(1,3,224,224,requires_grad=True)
# torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')
# Convert the PyTorch model to TorchScript format.
# A traced module keeps the train/eval behaviour that was active while tracing,
# so switch to eval mode first; otherwise layers such as BatchNorm/Dropout would
# be exported with their training-time behaviour.
model.eval()
# The example input must live on the same device as the model (which was moved
# to `device` above); a CPU tensor fed to a CUDA model raises a device-mismatch error.
image = torch.randn(1, 3, 224, 224, device=device)
# torch.jit.trace already executes the model on `image`, so no separate forward
# pass is needed.
# NOTE(review): the variable name and output path say "resnet50", but the model
# is built by modle.restnet18_cls2 — confirm the target repository entry is intended.
resnet50_traced = torch.jit.trace(model, image)
# resnet50_traced.save('/workspace/model/resnet50/model.pt')
torch.jit.save(resnet50_traced, "/home/qiangyu/cls/model_repository/resnet_50/1/model.pt")