-
Notifications
You must be signed in to change notification settings - Fork 1
/
Prepare_Data.py
121 lines (92 loc) · 5.54 KB
/
Prepare_Data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
Create datasets and dataloaders for models
"""
## Python standard libraries
from __future__ import print_function
from __future__ import division
import pdb
from Datasets.Split_Data import DataSplit
import ssl
## PyTorch dependencies
import torch
## Local external libraries
from Datasets.Pytorch_Datasets import *
from Datasets.Get_transform import *
from barbar import Bar
def Compute_Mean_STD(trainloader):
    """Compute per-channel mean and (approximate) std over a dataloader.

    Iterates once over ``trainloader``; each batch is expected to yield a
    tuple whose first element is an image tensor of shape [B, C, H, W].

    Adapted from: https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch

    Args:
        trainloader: iterable of batches; ``batch[0]`` must be the image tensor.

    Returns:
        (mean, std): 1-D tensors of length C (one entry per channel).

    Note:
        std is sqrt of the average of per-image variances, which slightly
        underestimates the true pooled std (it ignores the variance of the
        per-image means). Kept as-is to match the original normalization.
    """
    print('Computing Mean/STD')
    nimages = 0
    mean = 0.0
    var = 0.0
    # Loop index from the original enumerate() was unused, so it is dropped.
    for batch_target in Bar(trainloader):
        batch = batch_target[0]
        # Rearrange batch to shape [B, C, W * H] so stats reduce over pixels.
        batch = batch.view(batch.size(0), batch.size(1), -1)
        # Update total number of images seen so far.
        nimages += batch.size(0)
        # Accumulate per-image channel means and variances.
        mean += batch.mean(2).sum(0)
        var += batch.var(2).sum(0)
    mean /= nimages
    var /= nimages
    std = torch.sqrt(var)
    print()
    return mean, std
def Prepare_DataLoaders(Network_parameters, split):
    """Build train/val/test dataloaders for the configured dataset.

    Args:
        Network_parameters: dict with at least 'Dataset', 'data_dir',
            'batch_size' (a dict with 'train'/'val'/'test' keys),
            'num_workers', and 'pin_memory'.
        split: integer used as the random seed for the train/val/test split.

    Returns:
        dict mapping 'train' / 'val' / 'test' to torch DataLoaders.

    Raises:
        RuntimeError: if Network_parameters['Dataset'] is not supported.
    """
    # Work around SSL certificate verification failures on dataset downloads.
    ssl._create_default_https_context = ssl._create_unverified_context

    Dataset = Network_parameters['Dataset']
    data_dir = Network_parameters['data_dir']

    # Kept global for backward compatibility: the original code exposed
    # data_transforms at module level, and other modules may read it.
    global data_transforms
    data_transforms = get_transform(Network_parameters, input_size=224)

    # All three dataset classes share the same (data_dir, transform=...)
    # constructor signature, which previously produced three copy-pasted
    # branches differing only in the class name.
    dataset_classes = {
        'LeavesTex': LeavesTex1200,
        'PlantVillage': PlantVillage,
        'DeepWeeds': DeepWeeds,
    }
    if Dataset not in dataset_classes:
        raise RuntimeError('{} Dataset not implemented'.format(Dataset))

    dataset_cls = dataset_classes[Dataset]
    train_dataset = dataset_cls(data_dir, transform=data_transforms["train"])
    val_dataset = dataset_cls(data_dir, transform=data_transforms["test"])
    test_dataset = dataset_cls(data_dir, transform=data_transforms["test"])

    if Dataset == 'LeavesTex' or Dataset == 'DeepWeeds':
        # Seeded random train/val/test split (seed = `split` for CV folds).
        data_split = DataSplit(train_dataset, val_dataset, test_dataset,
                               shuffle=False, random_seed=split)
        train_loader, val_loader, test_loader = data_split.get_split(
            batch_size=Network_parameters['batch_size']['train'],
            num_workers=Network_parameters['num_workers'],
            show_sample=False,
            val_batch_size=Network_parameters['batch_size']['val'],
            test_batch_size=Network_parameters['batch_size']['test'])
        dataloaders_dict = {'train': train_loader,
                            'val': val_loader,
                            'test': test_loader}
    else:
        # PlantVillage: the original code built DataSplit loaders and then
        # unconditionally replaced them with plain DataLoaders over the full
        # datasets; that dead DataSplit work is skipped here — the returned
        # dict is identical.
        image_datasets = {'train': train_dataset,
                          'val': val_dataset,
                          'test': test_dataset}
        dataloaders_dict = {
            x: torch.utils.data.DataLoader(
                image_datasets[x],
                batch_size=Network_parameters['batch_size'][x],
                num_workers=Network_parameters['num_workers'],
                pin_memory=Network_parameters['pin_memory'],
                shuffle=False,
            )
            for x in ['train', 'val', 'test']}

    return dataloaders_dict