-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathMetaLearning.py
218 lines (170 loc) · 8.42 KB
/
MetaLearning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from copy import deepcopy
from torch.utils.data import DataLoader
from AutoEncoder import *
from DataAugment import limit_size, do_nothing
from Datasets import *
from VAE_GAN_train import load_image_datasets
def augment_tensor_dataset(tensor_dataset):
    """Run every sample through the augmentation pipeline and restack.

    For each tensor in `tensor_dataset`, `do_nothing` produces one or more
    variants (the pipeline is currently an identity pass-through), each variant
    is size-limited via `limit_size`, and all results are stacked into a single
    new tensor along dim 0.
    """
    collected = []
    for sample in tensor_dataset:
        variants = do_nothing(sample)
        collected.extend(limit_size(v) for v in variants)
    # Convert the accumulated list back into one tensor.
    return torch.stack(collected)
class MAML:
    """Model-Agnostic Meta-Learning driver around a VAE.

    Wraps a `VAEModel` (encoder/decoder and `latent_dim` come from the
    star-import of AutoEncoder) together with the inner-loop learning rate
    `inner_lr` and the meta (outer-loop) step size `beta`. `d` is the torch
    device the model is placed on.
    """
    def __init__(self, inner_lr, beta, d):
        self.device = d
        encoder = VAEEncoder(latent_dim)
        decoder = VAEDecoder(latent_dim)
        self.model = VAEModel(encoder, decoder).to(self.device)
        self.inner_lr = inner_lr  # learning rate for per-task (inner) adaptation
        self.beta = beta  # meta (outer) step size
        self.grad_clip_norm = 1.0  # max norm used for gradient clipping

    def inner_update(self, x):
        """One task-specific adaptation step on batch `x`; returns the inner loss.

        NOTE(review): a fresh Adam optimizer is constructed on every call, so
        optimizer state (first/second moments) never persists between inner
        steps — effectively a single cold Adam step each time. Confirm this is
        intended rather than a persistent optimizer held on `self`.
        """
        x_sample, z_mu, z_var = self.model(x)
        inner_loss = compute_task_loss(x_sample, x, z_mu, z_var)
        self.model.zero_grad()
        inner_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip_norm)  # gradient clipping
        inner_optimizer = optim.Adam(self.model.parameters(), lr=self.inner_lr)
        inner_optimizer.step()
        return inner_loss

    def meta_update(self, num_tasks, general_loader, specific_loader, num_samples_per_task):
        """Run `num_tasks` adapt-then-evaluate rounds and apply one meta step.

        Per task: sample a support batch, adapt with `inner_update`, sample a
        query batch, compute the query loss, backprop it, and fold a scalar
        summary of each parameter's gradient into an accumulator. After all
        tasks, the accumulated (averaged over tasks) values are added to the
        model parameters scaled by `beta`.
        """
        # Cast model parameters to float32 so the zero-like accumulator below
        # matches dtypes.
        for name, param in self.model.named_parameters():
            if param.data.dtype != torch.float32:
                param.data = param.data.float()
        # Zero-initialized accumulator with the same shapes as the parameters.
        param_dict = deepcopy(self.model.state_dict())
        param_dict = {name: torch.zeros_like(param_dict[name], dtype=torch.float32, requires_grad=True) for name in
                      param_dict}
        for _ in range(num_tasks):
            # Support batch: adapt the model to this task.
            x_task = sample_task_data(general_loader, specific_loader, num_samples_per_task)
            # NOTE(review): 512x512 spatial size and `input_channel` are
            # assumed to match the loaded images — confirm upstream.
            x_task_viewed = x_task.view(-1, input_channel, 512, 512)
            self.inner_update(x_task_viewed)
            updated_param = deepcopy(self.model.state_dict())
            # Query batch: evaluate the adapted model.
            x_query = sample_task_data(general_loader, specific_loader, num_samples_per_task)
            x_query_viewed = x_query.view(-1, input_channel, 512, 512)
            # NOTE(review): re-loading the state_dict just deep-copied from the
            # same model is a no-op — presumably a leftover.
            self.model.load_state_dict(updated_param)
            x_sample, z_mu, z_var = self.model(x_query_viewed)
            # NOTE(review): the reconstruction target here is the SUPPORT batch
            # (x_task_viewed) while the reconstruction came from the QUERY
            # batch — looks like it should be x_query_viewed; confirm.
            task_loss = compute_task_loss(x_sample, x_task_viewed, z_mu, z_var)
            print('\r\rtast_loss:', task_loss.item())
            self.model.zero_grad()
            task_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip_norm)  # gradient clipping
            meta_grad = {}
            for name, params in zip(self.model.state_dict(), self.model.parameters()):
                if params.grad is not None:
                    # Average the gradient down to a scalar per parameter tensor.
                    # NOTE(review): this collapses all per-element gradient
                    # information into one number — confirm this is deliberate.
                    meta_grad[name] = torch.mean(params.grad.data)
            for name in param_dict:
                if name in meta_grad:
                    # Broadcast the scalar gradient over an all-ones tensor and
                    # accumulate it.
                    param_dict[name] = param_dict[name] + meta_grad[name] * torch.ones_like(
                        param_dict[name])
        # Apply the meta step: params += beta * (accumulated / num_tasks).
        # NOTE(review): adding the gradient moves parameters UPHILL on the
        # loss; a descent step would subtract — confirm the sign.
        net_params = self.model.state_dict()
        net_params_new = {name: net_params[name] + self.beta * param_dict[name] / num_tasks for name in net_params}
        self.model.load_state_dict(net_params_new)
def train_maml_vae(mlma_model, general_loader, specific_loader, num_tasks, num_inner_steps, num_samples_per_task,
                   meta_iteration):
    """Run one meta-iteration: `num_inner_steps` meta updates on `mlma_model`.

    Args:
        mlma_model: MAML instance whose `meta_update` performs one meta step.
        general_loader / specific_loader: data loaders forwarded to meta_update.
        num_tasks: tasks sampled per meta update.
        num_inner_steps: how many meta updates to run this iteration.
        num_samples_per_task: samples drawn per task.
        meta_iteration: current iteration index (printed for progress only).
    """
    # Fix: the original declared `global device` but never used it — removed.
    torch.cuda.empty_cache()  # release cached GPU memory between iterations
    print(f"Meta Iteration: {meta_iteration}")
    for inner_step in range(num_inner_steps):
        mlma_model.meta_update(num_tasks, general_loader, specific_loader, num_samples_per_task)
        print(f"\rInner Step: {inner_step + 1}/{num_inner_steps}")
def compute_task_loss(x_recon, x, z_mu, z_var):
    """Return the total VAE loss for one batch.

    Sum of a summed-MSE reconstruction term and the analytic Gaussian KL
    divergence, where `z_var` is interpreted as log-variance.
    """
    # Reconstruction term: summed squared error between input and reconstruction.
    reconstruction_term = F.mse_loss(x_recon, x, reduction='sum')
    # KL-divergence term against the unit Gaussian prior.
    kl_term = -0.5 * torch.sum(1 + z_var - z_mu.pow(2) - z_var.exp())
    return reconstruction_term + kl_term
def sample_task_data(general_loader, specific_loader, num_samples):
    """
    Sample task-specific data from the general and specific data loaders.

    Draws one batch from each loader, picks `num_samples` random rows (with
    replacement) from each, concatenates them along dim 0, and moves the
    result to the module-level `device`.
    """
    general_batch = next(iter(general_loader))
    specific_batch = next(iter(specific_loader))
    # Fix: a DataLoader over a TensorDataset yields a list/tuple of field
    # tensors, so len() gave the field count and tensor-indexing the list
    # raised. Unwrap to the underlying data tensor when needed.
    if isinstance(general_batch, (list, tuple)):
        general_batch = general_batch[0]
    if isinstance(specific_batch, (list, tuple)):
        specific_batch = specific_batch[0]
    general_indices = torch.randint(0, len(general_batch), (num_samples,))
    specific_indices = torch.randint(0, len(specific_batch), (num_samples,))
    task_data = torch.cat((general_batch[general_indices], specific_batch[specific_indices]), dim=0)
    global device  # set in the __main__ block
    return task_data.to(device)
def seg_tensor_dataset(tensor_dataset, split_index=5):
    """Split a tensor dataset into two stacked tensors at `split_index`.

    Each tensor is round-tripped through PIL (ToPILImage -> ToTensor), which
    the original also did — presumably for its quantization/normalization
    side effect; confirmed-identical behavior is preserved. Samples with
    index < `split_index` go to the first part, the rest to the second.

    Generalization: the split point was hard-coded to 5; it is now a
    backward-compatible keyword argument defaulting to 5.

    Raises if either part ends up empty (torch.stack on an empty list),
    exactly as the original did.
    """
    part1, part2 = [], []
    for index, tensor in enumerate(tensor_dataset):
        image = transforms.ToPILImage()(tensor)
        restored = transforms.ToTensor()(image)
        if index < split_index:
            part1.append(restored)
        else:
            part2.append(restored)
    return torch.stack(part1), torch.stack(part2)
def merge_datasets(tensor_dataset_list: list):
    """Concatenate every sample from each dataset into one stacked tensor.

    Iterating a tensor yields its dim-0 slices, so this flattens all rows of
    all datasets into a single tensor along a new dim 0.
    """
    all_samples = [sample for dataset in tensor_dataset_list for sample in dataset]
    return torch.stack(all_samples)
def load_datasets(zoom_factor=1.2):
    """Load the general training images and the optional specific test set.

    Returns:
        (general_data, specific_data): `general_data` is the augmented tensor
        built from images under ./cut_imgs/; `specific_data` is the augmented
        test dataset loaded from disk, or None when the file is missing.
    """
    image_directory = "./cut_imgs/"
    # Read the images and build the dataset (load_image_datasets is a project
    # helper from VAE_GAN_train).
    loaded_datasets = [load_image_datasets(image_directory)]
    # Merge the training datasets.
    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
    general_data = augment_tensor_dataset(combined_dataset)
    # Load the test dataset, if present.
    test_dataset_path = './datasets/test_dataset.pt'
    if os.path.isfile(test_dataset_path):
        specific_data = torch.load(test_dataset_path)
        print(f"Test dataset has been loaded and merged from {test_dataset_path}.")
        augmented_specific_data = []
        for image_tensor in specific_data:
            # NOTE(review): here do_nothing is called with two arguments while
            # augment_tensor_dataset calls it with one — confirm its signature
            # accepts an optional zoom factor.
            augmented_tensor = do_nothing(image_tensor, zoom_factor)
            augmented_specific_data.append(augmented_tensor)
        specific_data = augmented_specific_data
    else:
        # Best-effort: a missing test set is reported, not fatal.
        print("Test dataset files not found.")
        specific_data = None
    return general_data, specific_data
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # general_data: general-purpose training material
    # specific_data: training material for the designated task
    general_data, specific_data = load_datasets()
    general_dataset = TensorDataset(general_data)
    specific_dataset = TensorDataset(specific_data)
    general_loader = DataLoader(general_dataset, batch_size=5, shuffle=True)
    specific_loader = DataLoader(specific_dataset, batch_size=5, shuffle=True)
    # Build the MAML wrapper (inner lr, meta step size, device).
    maml_model = MAML(1e-5, 1e-8, device)
    model_path = './saved_model/VAE_cold_start.pth'
    model_dir = os.path.dirname(model_path)
    print('Try to load model from', model_path)
    # Create the checkpoint directory if it does not exist yet.
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
        print(f"Created directory '{model_dir}' for saving models.")
    if os.path.isfile(model_path):
        try:
            maml_model.model.load_state_dict(torch.load(model_path, map_location=device))
            print("Model loaded successfully from '{}'".format(model_path))
        except Exception as e:
            # Best-effort warm start: a corrupt/incompatible checkpoint is
            # reported and training starts fresh.
            print("Failed to load model. Starting from scratch. Error: ", e)
    else:
        print("No saved model found at '{}'. Starting from scratch.".format(model_path))
    # Training-loop hyperparameters.
    num_meta_iterations = 100  # number of meta iterations
    num_tasks = 5  # tasks per meta update
    num_inner_steps = 5  # inner update steps per meta iteration
    num_samples_per_task = 10  # samples drawn per task
    # Training loop.
    for meta_iteration in range(num_meta_iterations):
        print('-' * 10, 'Meta Iteration', meta_iteration, '-' * 10)
        train_maml_vae(maml_model, general_loader, specific_loader, num_tasks, num_inner_steps, num_samples_per_task,
                       meta_iteration)
        # Save the current model parameters.
        # Fix: reuse model_path instead of a second hard-coded copy of the
        # same string, so the load and save paths cannot drift apart.
        torch.save(maml_model.model.state_dict(), model_path)