diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f3edfe5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,53 @@
+# Introduction
+
+This is the implementation of the paper [SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion](https://ieeexplore.ieee.org/abstract/document/9528046).
+
+## Requirements
+
+* python >= 3.6
+* pytorch == 1.7
+* torchvision == 0.8
+
+## Datasets
+
+You can download the datasets [here](https://pan.baidu.com/s/1kUja4iau37MwLnGI8_lMWg?pwd=eapv).
+
+## Test
+
+```shell
+python test.py
+```
+
+## Train from scratch
+
+### Step 1
+
+```shell
+python train_search.py
+```
+
+### Step 2
+
+Find the string describing the searched architecture in the log file, then copy and paste it into `genotypes.py`. The format should be consistent with the existing architecture strings.
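+
+For reference, a searched architecture string looks like the entries already in `genotypes.py` (the operations and node indices will differ from run to run):
+
+```python
+genotype_en1 = Genotype(cell=[('sep_conv_3x3', 0), ('Spatialattention', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_5x5', 3), ('Denseblocks', 4), ('Denseblocks', 3)], cell_concat=range(2, 6))
+```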
+
+### Step 3
+
+```shell
+python train.py
+```
+
+## Citation
+
+If you use any part of this code in your research, please cite our [paper](https://ieeexplore.ieee.org/abstract/document/9528046):
+
+```
+@ARTICLE{9528046,
+author={Liu, Jinyuan and Wu, Yuhui and Huang, Zhanbo and Liu, Risheng and Fan, Xin},
+journal={IEEE Signal Processing Letters},
+title={SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion},
+year={2021},
+volume={28},
+number={},
+pages={1818-1822},
+doi={10.1109/LSP.2021.3109818}}
+```
diff --git a/Strategy_vsm/A_001.jpg b/Strategy_vsm/A_001.jpg
new file mode 100644
index 0000000..653e225
Binary files /dev/null and b/Strategy_vsm/A_001.jpg differ
diff --git a/Strategy_vsm/A_002.jpg b/Strategy_vsm/A_002.jpg
new file mode 100644
index 0000000..40ee477
Binary files /dev/null and b/Strategy_vsm/A_002.jpg differ
diff --git a/Strategy_vsm/A_003.jpg b/Strategy_vsm/A_003.jpg
new file mode 100644
index 0000000..d44fb6b
Binary files /dev/null and b/Strategy_vsm/A_003.jpg differ
diff --git a/Strategy_vsm/A_004.jpg b/Strategy_vsm/A_004.jpg
new file mode 100644
index 0000000..78c197e
Binary files /dev/null and b/Strategy_vsm/A_004.jpg differ
diff --git a/Strategy_vsm/A_005.jpg b/Strategy_vsm/A_005.jpg
new file mode 100644
index 0000000..df54de5
Binary files /dev/null and b/Strategy_vsm/A_005.jpg differ
diff --git a/Strategy_vsm/A_006.jpg b/Strategy_vsm/A_006.jpg
new file mode 100644
index 0000000..9087c88
Binary files /dev/null and b/Strategy_vsm/A_006.jpg differ
diff --git a/Strategy_vsm/A_007.jpg b/Strategy_vsm/A_007.jpg
new file mode 100644
index 0000000..94859dc
Binary files /dev/null and b/Strategy_vsm/A_007.jpg differ
diff --git a/Strategy_vsm/A_008.jpg b/Strategy_vsm/A_008.jpg
new file mode 100644
index 0000000..4b58436
Binary files /dev/null and b/Strategy_vsm/A_008.jpg differ
diff --git a/Strategy_vsm/A_009.jpg b/Strategy_vsm/A_009.jpg
new file mode 100644
index 0000000..ede7f0e
Binary files /dev/null and b/Strategy_vsm/A_009.jpg differ
diff --git a/Strategy_vsm/A_010.jpg b/Strategy_vsm/A_010.jpg
new file mode 100644
index 0000000..2edd60f
Binary files /dev/null and b/Strategy_vsm/A_010.jpg differ
diff --git a/Strategy_vsm/A_011.jpg b/Strategy_vsm/A_011.jpg
new file mode 100644
index 0000000..2d2fce7
Binary files /dev/null and b/Strategy_vsm/A_011.jpg differ
diff --git a/Strategy_vsm/A_012.jpg b/Strategy_vsm/A_012.jpg
new file mode 100644
index 0000000..46ce7da
Binary files /dev/null and b/Strategy_vsm/A_012.jpg differ
diff --git a/Strategy_vsm/A_013.jpg b/Strategy_vsm/A_013.jpg
new file mode 100644
index 0000000..76c6882
Binary files /dev/null and b/Strategy_vsm/A_013.jpg differ
diff --git a/Strategy_vsm/A_014.jpg b/Strategy_vsm/A_014.jpg
new file mode 100644
index 0000000..9a81308
Binary files /dev/null and b/Strategy_vsm/A_014.jpg differ
diff --git a/Strategy_vsm/A_015.jpg b/Strategy_vsm/A_015.jpg
new file mode 100644
index 0000000..db165ab
Binary files /dev/null and b/Strategy_vsm/A_015.jpg differ
diff --git a/Strategy_vsm/A_016.jpg b/Strategy_vsm/A_016.jpg
new file mode 100644
index 0000000..86d0444
Binary files /dev/null and b/Strategy_vsm/A_016.jpg differ
diff --git a/Strategy_vsm/A_017.jpg b/Strategy_vsm/A_017.jpg
new file mode 100644
index 0000000..8a261da
Binary files /dev/null and b/Strategy_vsm/A_017.jpg differ
diff --git a/Strategy_vsm/A_018.jpg b/Strategy_vsm/A_018.jpg
new file mode 100644
index 0000000..2255083
Binary files /dev/null and b/Strategy_vsm/A_018.jpg differ
diff --git a/Strategy_vsm/A_019.jpg b/Strategy_vsm/A_019.jpg
new file mode 100644
index 0000000..666cf30
Binary files /dev/null and b/Strategy_vsm/A_019.jpg differ
diff --git a/Strategy_vsm/A_020.jpg b/Strategy_vsm/A_020.jpg
new file mode 100644
index 0000000..96dcd33
Binary files /dev/null and b/Strategy_vsm/A_020.jpg differ
diff --git a/Strategy_vsm/A_021.jpg b/Strategy_vsm/A_021.jpg
new file mode 100644
index 0000000..d3de9fb
Binary files /dev/null and b/Strategy_vsm/A_021.jpg differ
diff --git a/Strategy_vsm/A_022.jpg b/Strategy_vsm/A_022.jpg
new file mode 100644
index 0000000..da214e2
Binary files /dev/null and b/Strategy_vsm/A_022.jpg differ
diff --git a/Strategy_vsm/A_023.jpg b/Strategy_vsm/A_023.jpg
new file mode 100644
index 0000000..1f2116b
Binary files /dev/null and b/Strategy_vsm/A_023.jpg differ
diff --git a/Strategy_vsm/A_024.jpg b/Strategy_vsm/A_024.jpg
new file mode 100644
index 0000000..6d944a5
Binary files /dev/null and b/Strategy_vsm/A_024.jpg differ
diff --git a/Strategy_vsm/A_025.jpg b/Strategy_vsm/A_025.jpg
new file mode 100644
index 0000000..af6b658
Binary files /dev/null and b/Strategy_vsm/A_025.jpg differ
diff --git a/Strategy_vsm/A_026.jpg b/Strategy_vsm/A_026.jpg
new file mode 100644
index 0000000..f95ed84
Binary files /dev/null and b/Strategy_vsm/A_026.jpg differ
diff --git a/Strategy_vsm/A_027.jpg b/Strategy_vsm/A_027.jpg
new file mode 100644
index 0000000..f944cd7
Binary files /dev/null and b/Strategy_vsm/A_027.jpg differ
diff --git a/Strategy_vsm/A_028.jpg b/Strategy_vsm/A_028.jpg
new file mode 100644
index 0000000..00a1f88
Binary files /dev/null and b/Strategy_vsm/A_028.jpg differ
diff --git a/Strategy_vsm/A_029.jpg b/Strategy_vsm/A_029.jpg
new file mode 100644
index 0000000..ed5999a
Binary files /dev/null and b/Strategy_vsm/A_029.jpg differ
diff --git a/Strategy_vsm/A_030.jpg b/Strategy_vsm/A_030.jpg
new file mode 100644
index 0000000..d82e011
Binary files /dev/null and b/Strategy_vsm/A_030.jpg differ
diff --git a/Strategy_vsm/A_031.jpg b/Strategy_vsm/A_031.jpg
new file mode 100644
index 0000000..fd7690d
Binary files /dev/null and b/Strategy_vsm/A_031.jpg differ
diff --git a/Strategy_vsm/A_032.jpg b/Strategy_vsm/A_032.jpg
new file mode 100644
index 0000000..af8a5b9
Binary files /dev/null and b/Strategy_vsm/A_032.jpg differ
diff --git a/Strategy_vsm/A_033.jpg b/Strategy_vsm/A_033.jpg
new file mode 100644
index 0000000..73da994
Binary files /dev/null and b/Strategy_vsm/A_033.jpg differ
diff --git a/Strategy_vsm/A_034.jpg b/Strategy_vsm/A_034.jpg
new file mode 100644
index 0000000..bbb1702
Binary files /dev/null and b/Strategy_vsm/A_034.jpg differ
diff --git a/Strategy_vsm/A_035.jpg b/Strategy_vsm/A_035.jpg
new file mode 100644
index 0000000..8b81dc3
Binary files /dev/null and b/Strategy_vsm/A_035.jpg differ
diff --git a/Strategy_vsm/A_036.jpg b/Strategy_vsm/A_036.jpg
new file mode 100644
index 0000000..9006b0c
Binary files /dev/null and b/Strategy_vsm/A_036.jpg differ
diff --git a/Strategy_vsm/A_037.jpg b/Strategy_vsm/A_037.jpg
new file mode 100644
index 0000000..65c67af
Binary files /dev/null and b/Strategy_vsm/A_037.jpg differ
diff --git a/architect.py b/architect.py
new file mode 100644
index 0000000..100b86f
--- /dev/null
+++ b/architect.py
@@ -0,0 +1,103 @@
+import torch
+import numpy as np
+import torch.nn as nn
+from torch.autograd import Variable
+
+
+def _concat(xs):
+ return torch.cat([x.view(-1) for x in xs])
+
+
+class Architect(object):
+
+ def __init__(self, model_former, model_latter, args, mse_loss, ssim_loss):
+ self.network_momentum = args.momentum
+ self.network_weight_decay = args.weight_decay
+ self.model_former = model_former
+ self.model_latter = model_latter
+ self.mse_loss = mse_loss
+ self.ssim_loss = ssim_loss
+ para = [{'params': model_former.arch_parameters(), 'lr': args.arch_learning_rate},
+ {'params': model_latter.arch_parameters(), 'lr': args.arch_learning_rate}]
+ self.optimizer = torch.optim.Adam(para,
+ lr=args.arch_learning_rate, betas=(0.5, 0.999),
+ weight_decay=args.arch_weight_decay)
+
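+    # The methods below implement the second-order (unrolled) DARTS update and
+    # are retained from the original DARTS code. They reference a single
+    # `self.model`, which this two-network variant does not define, so only the
+    # first-order path (`unrolled=False` in `step`) is usable as-is.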
+ def _compute_unrolled_model(self, input, eta, network_optimizer):
+ loss = self.model._loss(input)
+ theta = _concat(self.model.parameters()).data
+ try:
+ moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(
+ self.network_momentum)
+        except KeyError:  # no momentum buffer yet (first optimizer step)
+ moment = torch.zeros_like(theta)
+ dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay * theta
+ unrolled_model = self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta[0]))
+ return unrolled_model
+
+ def step(self, input_train, input_valid, eta, network_optimizer, unrolled):
+ self.optimizer.zero_grad()
+ if unrolled:
+ self._backward_step_unrolled(input_train, input_valid, eta, network_optimizer)
+ else:
+ self._backward_step(input_valid)
+ self.optimizer.step()
+
+ def _backward_step(self, input_valid):
+        en1, en2 = self.model_former(input_valid)
+ output_valid = self.model_latter(en1, en2)
+ ssim_loss_value = 0.
+ pixel_loss_value = 0.
+ for output, input in zip(output_valid, input_valid):
+ output, input = torch.unsqueeze(output, 0), torch.unsqueeze(input, 0)
+ pixel_loss_temp = self.mse_loss(input, output)
+ ssim_loss_temp = self.ssim_loss(input, output, normalize=True, val_range=255)
+ ssim_loss_value += (1 - ssim_loss_temp)
+ pixel_loss_value += pixel_loss_temp
+ ssim_loss_value /= len(output_valid)
+ pixel_loss_value /= len(output_valid)
+
+        total_loss = pixel_loss_value + 100 * ssim_loss_value  # weighted sum: the SSIM term is scaled by 100
+ total_loss.backward()
+
+ def _backward_step_unrolled(self, input_train, input_valid, eta, network_optimizer):
+ unrolled_model = self._compute_unrolled_model(input_train, eta, network_optimizer)
+ unrolled_loss = unrolled_model._loss(input_valid)
+ unrolled_loss.backward()
+ dalpha = [v.grad for v in unrolled_model.arch_parameters()]
+ vector = [v.grad.data for v in unrolled_model.parameters()]
+ implicit_grads = self._hessian_vector_product(vector, input_train)
+ for g, ig in zip(dalpha, implicit_grads):
+ g.data.sub_(ig.data, alpha=eta[0])
+ for v, g in zip(self.model.arch_parameters(), dalpha):
+ if v.grad is None:
+ v.grad = Variable(g.data)
+ else:
+ v.grad.data.copy_(g.data)
+
+ def _construct_model_from_theta(self, theta):
+ model_new = self.model.new()
+ model_dict = self.model.state_dict()
+ params, offset = {}, 0
+ for k, v in self.model.named_parameters():
+ v_length = np.prod(v.size())
+ params[k] = theta[offset: offset + v_length].view(v.size())
+ offset += v_length
+ assert offset == len(theta)
+ model_dict.update(params)
+ model_new.load_state_dict(model_dict)
+ return model_new.cuda()
+
+ def _hessian_vector_product(self, vector, input, r=1e-2):
+ R = r / _concat(vector).norm()
+ for p, v in zip(self.model.parameters(), vector):
+ p.data.add_(v, alpha=R)
+ loss = self.model._loss(input)
+ grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
+ for p, v in zip(self.model.parameters(), vector):
+ p.data.sub_(v, alpha=2 * R)
+ loss = self.model._loss(input)
+ grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
+ for p, v in zip(self.model.parameters(), vector):
+ p.data.add_(v, alpha=R)
+ return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
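+
+# A minimal usage sketch for the search phase (names are illustrative; the
+# actual driver lives in train_search.py):
+#   architect = Architect(encoder, decoder, args, mse_loss, ssim_loss)
+#   architect.step(input_train, input_valid, eta=lr, network_optimizer=weight_optimizer, unrolled=False)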
diff --git a/genotypes.py b/genotypes.py
new file mode 100644
index 0000000..ecd53d5
--- /dev/null
+++ b/genotypes.py
@@ -0,0 +1,19 @@
+from collections import namedtuple
+
+Genotype = namedtuple('Genotype', 'cell cell_concat')
+
+PRIMITIVES = [
+ 'none',
+ 'sep_conv_3x3',
+ 'sep_conv_5x5',
+ 'dil_conv_3x3',
+ 'dil_conv_5x5',
+ 'Spatialattention',
+ 'Denseblocks',
+ 'Residualblocks'
+]
+
+# epoch 20
+genotype_en1 = Genotype(cell=[('sep_conv_3x3', 0), ('Spatialattention', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_5x5', 3), ('Denseblocks', 4), ('Denseblocks', 3)], cell_concat=range(2, 6))
+genotype_en2 = Genotype(cell=[('Spatialattention', 1), ('dil_conv_3x3', 0), ('dil_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('Spatialattention', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], cell_concat=range(2, 6))
+genotype_de = Genotype(cell=[('sep_conv_3x3', 0), ('Denseblocks', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('Denseblocks', 4), ('sep_conv_3x3', 0)], cell_concat=range(2, 6))
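+
+# Each genotype lists (operation, input-node index) pairs, two per intermediate
+# node in build order; `cell_concat` names the nodes whose outputs are
+# concatenated to form the cell output (see `Cell._compile` in model.py).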
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..484eb18
--- /dev/null
+++ b/model.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from operations import *
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+class Cell(nn.Module):
+
+ def __init__(self, genotype, C_prev_prev, C_prev, C):
+ super(Cell, self).__init__()
+ print(C_prev_prev, C_prev, C)
+
+ self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
+ self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+ op_names, indices = zip(*genotype.cell)
+ concat = genotype.cell_concat
+ self._compile(C, op_names, indices, concat)
+
+ def _compile(self, C, op_names, indices, concat):
+ assert len(op_names) == len(indices)
+ self._steps = len(op_names) // 2
+ self._concat = concat
+ self.multiplier = len(concat)
+ self._ops = nn.ModuleList()
+ for name, index in zip(op_names, indices):
+ stride = 1
+ op = OPS[name](C, stride, True)
+ self._ops += [op]
+ self._indices = indices
+
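+    # Evaluate the cell as a small DAG: each intermediate node sums two
+    # op-transformed predecessor states, and the cell output concatenates the
+    # nodes listed in `_concat` along the channel dimension.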
+ def forward(self, s0, s1):
+ s0 = self.preprocess0(s0)
+ s1 = self.preprocess1(s1)
+ states = [s0, s1]
+ for i in range(self._steps):
+ h1 = states[self._indices[2 * i]]
+ h2 = states[self._indices[2 * i + 1]]
+ op1 = self._ops[2 * i]
+ op2 = self._ops[2 * i + 1]
+ h1 = op1(h1)
+ h2 = op2(h2)
+ s = h1 + h2
+ states += [s]
+ return torch.cat([states[i] for i in self._concat], dim=1)
+
+
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, C, layers, genotype):
+ super(Encoder, self).__init__()
+        self._inC = C          # cell channel width
+        self._layers = layers  # number of stacked cells
+ C_curr = 8
+
+ self.stem = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(1, 8, 3, padding=0, bias=False),
+ # nn.BatchNorm2d(8)
+ )
+
+ C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
+ self.cells = nn.ModuleList()
+ for i in range(layers):
+ cell = Cell(genotype, C_prev_prev, C_prev, C_curr)
+ self.cells += [cell]
+ C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
+
+ def forward(self, input):
+ s0 = s1 = self.stem(input)
+ for i, cell in enumerate(self.cells):
+ s0, s1 = s1, cell(s0, s1)
+ return s0, s1
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, C, layers, genotype):
+ super(Decoder, self).__init__()
+        self._inC = C          # cell channel width
+        self._layers = layers  # number of stacked cells
+ C_prev_prev, C_prev, C_curr = C*4, C*4, C
+ self.cells = nn.ModuleList()
+ for i in range(layers):
+ cell = Cell(genotype, C_prev_prev, C_prev, C_curr)
+ self.cells += [cell]
+ C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
+ self.pad = nn.ReflectionPad2d(1)
+ self.ConvLayer = nn.Conv2d(C_curr*4, 1, 3, padding=0)
+
+ def forward(self, s0, s1):
+ for i, cell in enumerate(self.cells):
+ s0, s1 = s1, cell(s0, s1)
+ output = self.pad(s1)
+ output = self.ConvLayer(output)
+ return output
+
+
+
diff --git a/model_search.py b/model_search.py
new file mode 100644
index 0000000..3ea1548
--- /dev/null
+++ b/model_search.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from operations import *
+from torch.autograd import Variable
+from genotypes import PRIMITIVES
+from genotypes import Genotype
+
+
+class MixedOp(nn.Module):
+
+ def __init__(self, C, stride):
+ super(MixedOp, self).__init__()
+ self._ops = nn.ModuleList()
+        for primitive in PRIMITIVES:  # 8 candidate operations
+ op = OPS[primitive](C, stride, False)
+ self._ops.append(op)
+
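+    # Continuous relaxation (DARTS): an edge's output is the weighted sum of
+    # all candidate operations applied to the same input, with weights given
+    # by a softmax over the architecture parameters.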
+ def forward(self, x, weights):
+ return sum(w * op(x) for w, op in zip(weights, self._ops))
+
+
+class Cell(nn.Module):
+
+ def __init__(self, steps, multiplier, C_prev_prev, C_prev, C):
+ super(Cell, self).__init__()
+ print(C_prev_prev, C_prev, C)
+ self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
+ self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
+ self._steps = steps
+ self._multiplier = multiplier
+ self._ops = nn.ModuleList()
+ self._bns = nn.ModuleList()
+        for i in range(self._steps):  # 4 intermediate nodes
+ for j in range(2 + i):
+ stride = 1
+ op = MixedOp(C, stride)
+                self._ops.append(op)  # 2+3+4+5 = 14 mixed edges in total
+
+ def forward(self, s0, s1, weights):
+ s0 = self.preprocess0(s0)
+ s1 = self.preprocess1(s1)
+ states = [s0, s1]
+ offset = 0
+        for i in range(self._steps):  # for each intermediate node
+            s = sum(self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states))  # weighted mixed ops over all predecessor states, summed into this node's output
+            offset += len(states)
+            states.append(s)
+        return torch.cat(states[-self._multiplier:], dim=1)  # concatenate the outputs of the last `multiplier` nodes
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, C, layers, steps=4, multiplier=4):
+ super(Encoder, self).__init__()
+        self._inC = C          # cell channel width
+        self._layers = layers  # number of stacked cells
+ self._steps = steps
+ self._multiplier = multiplier
+ C_curr = 8
+
+ self.stem = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(1, 8, 3, padding=0, bias=False),
+ # nn.BatchNorm2d(8)
+ )
+
+ C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
+ self.cells = nn.ModuleList()
+ for i in range(layers):
+ # C_curr = C*(2**i)
+ cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr)
+ self.cells += [cell]
+ C_prev_prev, C_prev = C_prev, multiplier * C_curr
+
+ self._initialize_alphas()
+
+ def new(self):
+ model_new = Encoder(self._inC, self._layers).cuda()
+ for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
+ x.data.copy_(y.data)
+ return model_new
+
+ def forward(self, input):
+ s0 = s1 = self.stem(input)
+ for i, cell in enumerate(self.cells):
+ weights = F.softmax(self.alphas, dim=-1)
+ s0, s1 = s1, cell(s0, s1, weights)
+ return s0, s1
+
+ def _initialize_alphas(self):
+ k = sum(1 for i in range(self._steps) for n in range(2 + i)) # 14
+ num_ops = len(PRIMITIVES)
+
+ self.alphas = Variable(1e-3 * torch.randn((k, num_ops))).cuda()
+ self.alphas.requires_grad = True
+
+ self._arch_parameters = [
+ self.alphas
+ ]
+
+ def arch_parameters(self):
+ return self._arch_parameters
+
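+    # Discretize the learned architecture: for each intermediate node, keep the
+    # two incoming edges with the strongest non-'none' weight, and on each kept
+    # edge keep the single strongest non-'none' operation.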
+ def genotype(self):
+ def _parse(weights):
+ gene = []
+ n = 2
+ start = 0
+ for i in range(self._steps):
+ end = start + n
+ W = weights[start:end].copy()
+ edges = sorted(range(i + 2),
+ key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[
+ :2]
+ for j in edges:
+ k_best = None
+ for k in range(len(W[j])):
+ if k != PRIMITIVES.index('none'):
+ if k_best is None or W[j][k] > W[j][k_best]:
+ k_best = k
+ gene.append((PRIMITIVES[k_best], j))
+ start = end
+ n += 1
+ return gene
+
+ gene_former = _parse(F.softmax(self.alphas, dim=-1).data.cpu().numpy())
+ concat = range(2 + self._steps - self._multiplier, self._steps + 2)
+ genotype = Genotype(
+ cell=gene_former, cell_concat=concat
+ )
+ return genotype
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, C, layers, steps=4, multiplier=4):
+ super(Decoder, self).__init__()
+        self._inC = C          # cell channel width
+        self._layers = layers  # number of stacked cells
+ self._steps = steps
+ self._multiplier = multiplier
+
+ C_prev_prev, C_prev, C_curr = C*4, C*4, C
+ self.cells = nn.ModuleList()
+ for i in range(layers):
+ # C_curr = C//(2**i)
+ cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr)
+ self.cells += [cell]
+ C_prev_prev, C_prev = C_prev, multiplier * C_curr
+ self.pad = nn.ReflectionPad2d(1)
+ self.ConvLayer = nn.Conv2d(C_curr*multiplier, 1, 3, padding=0)
+ # self.tanh = nn.Tanh()
+ self._initialize_alphas()
+
+ def new(self):
+ model_new = Decoder(self._inC, self._layers).cuda()
+ for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
+ x.data.copy_(y.data)
+ return model_new
+
+ def forward(self, s0, s1):
+ for i, cell in enumerate(self.cells):
+ weights = F.softmax(self.alphas, dim=-1)
+ s0, s1 = s1, cell(s0, s1, weights)
+ output = self.pad(s1)
+ output = self.ConvLayer(output)
+ return output
+
+ def _initialize_alphas(self):
+ k = sum(1 for i in range(self._steps) for n in range(2 + i)) # 14
+ num_ops = len(PRIMITIVES)
+
+ self.alphas = Variable(1e-3 * torch.randn((k, num_ops))).cuda()
+ self.alphas.requires_grad = True
+
+ self._arch_parameters = [
+ self.alphas
+ ]
+
+ def arch_parameters(self):
+ return self._arch_parameters
+
+ def genotype(self):
+ def _parse(weights):
+ gene = []
+ n = 2
+ start = 0
+ for i in range(self._steps):
+ end = start + n
+ W = weights[start:end].copy()
+ edges = sorted(range(i + 2),
+ key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[
+ :2]
+ for j in edges:
+ k_best = None
+ for k in range(len(W[j])):
+ if k != PRIMITIVES.index('none'):
+ if k_best is None or W[j][k] > W[j][k_best]:
+ k_best = k
+ gene.append((PRIMITIVES[k_best], j))
+ start = end
+ n += 1
+ return gene
+
+ gene = _parse(F.softmax(self.alphas, dim=-1).data.cpu().numpy())
+ concat = range(2 + self._steps - self._multiplier, self._steps + 2)
+ genotype = Genotype(
+ cell=gene, cell_concat=concat
+ )
+ return genotype
+
+
diff --git a/operations.py b/operations.py
new file mode 100644
index 0000000..99d91b9
--- /dev/null
+++ b/operations.py
@@ -0,0 +1,286 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+OPS = {
+ 'none': lambda C, stride, affine: Zero(stride),
+ 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
+ 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
+    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # padding=2, dilation=2
+    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # padding=4, dilation=2
+ 'NonLocalattention': lambda C, stride, affine: NLBasicBlock(C, stride),
+ 'Spatialattention': lambda C, stride, affine: Spatial_BasicBlock(C, stride),
+ 'Denseblocks': lambda C, stride, affine: ResidualDenseBlock(C, stride),
+ 'Residualblocks': lambda C, stride, affine: ResidualModule(C, stride),
+}
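+# All candidate operations use stride 1 and preserve spatial size.
+# Note: 'NonLocalattention' is registered here but absent from PRIMITIVES in
+# genotypes.py, so it never takes part in the search.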
+
+
+class BasicConv(nn.Module):
+ def __init__(self, in_planes, out_planes, stride=1, dilation=1, groups=1, relu=True, bn=False,
+ bias=False):
+ super(BasicConv, self).__init__()
+        # The padding table below keeps the spatial size unchanged for each
+        # (kernel_size, dilation) pair; only kernel_size = 3 is reachable here.
+ padding = 0
+ kernel_size = 3
+ if kernel_size == 3 and dilation == 1:
+ padding = 1
+ if kernel_size == 3 and dilation == 2:
+ padding = 2
+ if kernel_size == 5 and dilation == 1:
+ padding = 2
+ if kernel_size == 5 and dilation == 2:
+ padding = 4
+ if kernel_size == 7 and dilation == 1:
+ padding = 3
+ if kernel_size == 7 and dilation == 2:
+ padding = 6
+ self.out_channels = out_planes
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, groups=groups, bias=bias, padding_mode='reflect')
+ self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
+ self.relu = nn.PReLU() if relu else None
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.bn is not None:
+ x = self.bn(x)
+ if self.relu is not None:
+ x = self.relu(x)
+ return x
+
+
+class NonLocalBlock2D(nn.Module):
+ def __init__(self, in_channels, inter_channels, bias=False):
+ super(NonLocalBlock2D, self).__init__()
+
+ self.in_channels = in_channels
+ self.inter_channels = inter_channels
+
+ self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
+ padding=0, bias=bias)
+
+ self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1,
+ padding=0, bias=bias)
+ # for pytorch 0.3.1
+ # nn.init.constant(self.W.weight, 0)
+ # nn.init.constant(self.W.bias, 0)
+ # for pytorch 0.4.0
+ nn.init.constant_(self.W.weight, 0)
+ # nn.init.constant_(self.W.bias, 0)
+ self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
+ padding=0, bias=bias)
+
+ self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
+ padding=0, bias=bias)
+
+ def forward(self, x):
+ batch_size = x.size(0)
+
+ g_x = self.g(x).view(batch_size, self.inter_channels, -1)
+
+ g_x = g_x.permute(0, 2, 1)
+
+ theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+
+ theta_x = theta_x.permute(0, 2, 1)
+
+ phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
+
+ f = torch.matmul(theta_x, phi_x)
+
+ f_div_C = F.softmax(f, dim=1)
+
+ y = torch.matmul(f_div_C, g_x)
+
+ y = y.permute(0, 2, 1).contiguous()
+
+ y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+ W_y = self.W(y)
+ z = W_y + x
+ return z
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, padding_mode='reflect')
+
+
+class NLBasicBlock(nn.Module):
+ def __init__(self, inplanes, stride=1, with_norm=False):
+ super(NLBasicBlock, self).__init__()
+ self.with_norm = with_norm
+ kernel = 3
+ self.conv1 = conv3x3(inplanes, inplanes, stride)
+ self.conv2 = BasicConv(inplanes, inplanes, relu=False)
+ self.se = NonLocalBlock2D(inplanes, inplanes)
+ self.relu = nn.PReLU()
+ if self.with_norm:
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.bn2 = nn.BatchNorm2d(inplanes)
+
+ def forward(self, x):
+ out = x = self.conv1(x)
+ if self.with_norm:
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.se(out)
+ out += x
+ out = self.conv2(out)
+ if self.with_norm:
+ out = self.bn2(out)
+ out = self.relu(out)
+ # print(out.shape)
+ return out
+
+
+class ChannelPool(nn.Module):
+ def forward(self, x):
+ return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
+
+
+class spatial_attn_layer(nn.Module):
+ def __init__(self, kernel_size=5):
+ super(spatial_attn_layer, self).__init__()
+ self.compress = ChannelPool()
+ self.spatial = BasicConv(2, 1, relu=False)
+
+ def forward(self, x):
+ # import pdb;pdb.set_trace()
+ x_compress = self.compress(x)
+ x_out = self.spatial(x_compress)
+ scale = torch.sigmoid(x_out) # broadcasting
+ return x * scale
+
+
+class Spatial_BasicBlock(nn.Module):
+ def __init__(self, inplanes, stride=1, reduction=64, with_norm=False):
+ super(Spatial_BasicBlock, self).__init__()
+ self.with_norm = with_norm
+ kernel = 3
+ self.conv1 = conv3x3(inplanes, inplanes, stride)
+ self.conv2 = BasicConv(inplanes, inplanes, relu=False)
+ self.se = spatial_attn_layer(kernel)
+ self.relu = nn.PReLU()
+ if self.with_norm:
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.bn2 = nn.BatchNorm2d(inplanes)
+
+ def forward(self, x):
+ out = x = self.conv1(x)
+ if self.with_norm:
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ if self.with_norm:
+ out = self.bn2(out)
+ out = self.se(out)
+ out += x
+ out = self.relu(out)
+ return out
+
+class ResidualDenseBlock(nn.Module):
+ def __init__(self, in_channels, stride):
+ super(ResidualDenseBlock, self).__init__()
+ # gc: growth channel, i.e. intermediate channels
+
+ self.conv1 = BasicConv(in_channels, in_channels, stride, relu=False)
+ self.conv2 = BasicConv(in_channels * 2, in_channels, stride, relu=False)
+ self.conv3 = BasicConv(in_channels * 3, in_channels, stride, relu=False)
+
+ self.lrelu = nn.PReLU()
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        return x3 * 0.333333 + x  # residual connection with scaled dense output
+
+
+class ResidualModule(nn.Module):
+    def __init__(self, in_channels, stride, dilations=1):
+        super(ResidualModule, self).__init__()
+        self.op = nn.Sequential(
+            BasicConv(in_channels, in_channels, stride, dilation=dilations, relu=False, groups=in_channels),
+ nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=2, dilation=2, groups=in_channels,
+ bias=False, padding_mode='reflect'),
+ nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False),
+ # nn.BatchNorm2d(in_channels),
+ nn.PReLU(),
+ )
+
+ def forward(self, x):
+ res = self.op(x)
+ return x + res
+
+
+class ReLUConvBN(nn.Module):
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+ super(ReLUConvBN, self).__init__()
+ self.op = nn.Sequential(
+ nn.ReLU(inplace=False),
+ nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False, padding_mode='reflect'),
+ # nn.BatchNorm2d(C_out, affine=affine)
+ )
+
+ def forward(self, x):
+ return self.op(x)
+
+
+class DilConv(nn.Module):
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
+ super(DilConv, self).__init__()
+ self.op = nn.Sequential(
+ nn.ReLU(inplace=False),
+ nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=C_in, bias=False, padding_mode='reflect'),
+ nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+ # nn.BatchNorm2d(C_out, affine=affine),
+ )
+
+ def forward(self, x):
+ return self.op(x)
+
+# Depthwise-separable convolution
+
+class SepConv(nn.Module):
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+ super(SepConv, self).__init__()
+ self.op = nn.Sequential(
+ nn.ReLU(inplace=False),
+ nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False, padding_mode='reflect'),
+ nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
+ # nn.BatchNorm2d(C_in, affine=affine),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False, padding_mode='reflect'),
+ nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+ # nn.BatchNorm2d(C_out, affine=affine),
+ )
+
+ def forward(self, x):
+ return self.op(x)
+
+
+class Identity(nn.Module):
+
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return x
+
+
+class Zero(nn.Module):
+
+ def __init__(self, stride):
+ super(Zero, self).__init__()
+ self.stride = stride
+
+ def forward(self, x):
+ if self.stride == 1:
+ return x.mul(0.)
+ return x[:, :, ::self.stride, ::self.stride].mul(0.)
+
diff --git a/pytorch_msssim/__init__.py b/pytorch_msssim/__init__.py
new file mode 100644
index 0000000..3aead62
--- /dev/null
+++ b/pytorch_msssim/__init__.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn.functional as F
+from math import exp
+import numpy as np
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+ return gauss/gauss.sum()
+
+
+def create_window(window_size, channel=1):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+ return window
+
+
+def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+ if val_range is None:
+ if torch.max(img1) > 128:
+ max_val = 255
+ else:
+ max_val = 1
+
+ if torch.min(img1) < -0.5:
+ min_val = -1
+ else:
+ min_val = 0
+ L = max_val - min_val
+ else:
+ L = val_range
+
+ padd = 0
+ (_, channel, height, width) = img1.size()
+ if window is None:
+ real_size = min(window_size, height, width)
+ window = create_window(real_size, channel=channel).to(img1.device)
+
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
+
+ C1 = (0.01 * L) ** 2
+ C2 = (0.03 * L) ** 2
+
+ v1 = 2.0 * sigma12 + C2
+ v2 = sigma1_sq + sigma2_sq + C2
+ cs = torch.mean(v1 / v2) # contrast sensitivity
+
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+ if size_average:
+ ret = ssim_map.mean()
+ else:
+ ret = ssim_map.mean(1).mean(1).mean(1)
+
+ if full:
+ return ret, cs
+ return ret
+
+
+def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+ device = img1.device
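+    # Standard five-level MS-SSIM weights from Wang et al. (2003).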
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
+ levels = weights.size()[0]
+ mssim = []
+ mcs = []
+ for _ in range(levels):
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+ mssim.append(sim)
+ mcs.append(cs)
+
+ img1 = F.avg_pool2d(img1, (2, 2))
+ img2 = F.avg_pool2d(img2, (2, 2))
+
+ mssim = torch.stack(mssim)
+ mcs = torch.stack(mcs)
+ # print('mcs = torch.stack(mcs):', mcs)
+
+    # Normalize to avoid NaNs when training unstable models (not compliant with the original MS-SSIM definition)
+ if normalize:
+ mssim = (mssim + 1) / 2
+ mcs = (mcs + 1) / 2
+ # print('mssim:', mssim)
+ # print('mcs:', mcs)
+ pow1 = mcs ** weights
+ pow2 = mssim ** weights
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+ output = torch.prod(pow1[:-1] * pow2[-1])
+ return output
+
+
+# Classes to re-use window
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True, val_range=None):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.val_range = val_range
+
+ # Assume 1 channel for SSIM
+ self.channel = 1
+ self.window = create_window(window_size)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.dtype == img1.dtype:
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+ self.window = window
+ self.channel = channel
+
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+
+class MSSSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True, channel=3):
+ super(MSSSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = channel
+
+ def forward(self, img1, img2):
+ # TODO: store window between calls if possible
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..cd9d09e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,139 @@
+import cv2
+import torch
+from model import Encoder, Decoder
+from os.path import join
+from os import listdir
+import PIL.Image as Image
+import numpy as np
+import torch.backends.cudnn as cudnn
+import torchvision.transforms as transforms
+import os
+from os.path import exists
+from utils import get_train_images_auto, get_test_images
+import genotypes
+import torch.nn.functional as F
+import glob
+
+tensor_to_pil = transforms.ToPILImage()
+model_dir = r'E:\已投论文\SPL\code\trainJT'  # directory with the trained checkpoints; change to your own path
+encoder1_path1 = join(model_dir, 'encoder1_epoch6.pt')
+encoder2_path2 = join(model_dir, 'encoder2_epoch6.pt')
+decoder_path = join(model_dir, 'decoder_epoch6.pt')
+
+genotype_en1 = genotypes.genotype_en1
+genotype_en2 = genotypes.genotype_en2
+
+genotype2 = genotypes.genotype_de
+
+encoder1 = Encoder(16, 2, genotype_en1).cuda()
+encoder2 = Encoder(16, 2, genotype_en2).cuda()
+
+decoder = Decoder(16, 2, genotype2).cuda()
+
+
+params1 = torch.load(encoder1_path1)
+params2 = torch.load(encoder2_path2)
+params3 = torch.load(decoder_path)
+
+encoder1.load_state_dict(params1)
+encoder2.load_state_dict(params2)
+
+decoder.load_state_dict(params3)
+
+encoder1.eval()
+encoder2.eval()
+
+decoder.eval()
+
+
+par = os.getcwd()
+image_dir2 = par + '\\testImage'
+save_dir = par + '\\result'
+
+if not exists(save_dir):
+ os.mkdir(save_dir)
+
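+# Test-time fusion strategy: a visual-saliency-map (VSM) rule. For each pair of
+# encoder feature maps, per-intensity saliency weights are computed from the
+# intensity histogram (vsm2), and the maps are fused as a saliency-weighted sum
+# (vsf3).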
+def vsm2(tensor):
+ his = tensor.histc(bins=256, min=0, max=255)
+ sal = torch.zeros(256).to(torch.int64).cuda()
+ for i in range(256):
+ for j in range(256):
+ sal[i] += abs(j-i)*his[j]
+    sal = sal.div(sal.max())  # normalize saliency weights to [0, 1]
+    sal_map = torch.zeros_like(tensor).cuda().to(torch.float32)
+    for i in range(256):
+        sal_map[tensor == i] = sal[i]
+    return sal_map
+
+
+def vsf3(tensor1, tensor2):
+ t1 = (tensor1/tensor1.max()*255).to(torch.int)
+ t2 = (tensor2/tensor2.max()*255).to(torch.int)
+ weight1 = vsm2(t1)
+ weight2 = vsm2(t2)
+    fused = (0.5 + 0.5 * (weight1 - weight2)) * tensor1 + (0.5 + 0.5 * (weight2 - weight1)) * tensor2
+    return fused
+
+def fuse_L(ir_path, vis_path, save_dir, name):
+    tensor_ir = get_test_images(ir_path).cuda()
+    tensor_ir.requires_grad = False
+    tensor_vis = get_test_images(vis_path).cuda()
+    tensor_vis.requires_grad = False
+
+ en11, en12 = encoder1(tensor_ir)
+ en21, en22 = encoder2(tensor_vis)
+
+ en1 = vsf3(en11, en21)
+ en2 = vsf3(en12, en22)
+
+ tensor_f = decoder(en1, en2).cpu()
+ image_tensor = tensor_f.squeeze()
+
+ image_array = np.asarray(image_tensor.detach())
+ image_pil = Image.fromarray(image_array).convert('L')
+ image_pil.save(os.path.join(save_dir, name.split('.')[0] + '.jpg'))
+
+
+def fuse_RGB(ir_path, vis_path, save_dir, name):
+ ir = cv2.imread(ir_path)
+ vis = cv2.imread(vis_path)
+ vis_ycrcb = cv2.cvtColor(vis, cv2.COLOR_BGR2YCrCb)
+
+ tensor1 = torch.tensor(ir[:, :, 0], dtype=torch.float32).unsqueeze(0).unsqueeze(0).cuda()
+ tensor2 = torch.tensor(vis_ycrcb[:, :, 0], dtype=torch.float32).unsqueeze(0).unsqueeze(0).cuda()
+
+ en11, en12 = encoder1(tensor1)
+ en21, en22 = encoder2(tensor2)
+
+ en1 = vsf3(en11, en21)
+ en2 = vsf3(en12, en22)
+
+ tensor_f = decoder(en1, en2).cpu()
+
+ image_tensor = tensor_f.squeeze()
+ image_tensor = torch.clamp(image_tensor, 0, 255)
+ image_array = np.asarray(image_tensor.detach(), dtype=int)
+ re = np.stack([image_array, vis_ycrcb[:, :, 1], vis_ycrcb[:, :, 2]], axis=2).astype(np.uint8)
+ re = cv2.cvtColor(re, cv2.COLOR_YCrCb2BGR)
+ cv2.imwrite(os.path.join(save_dir, name.split('.')[0] + '_RGB.jpg'), re)
+ cv2.imwrite(os.path.join(save_dir, name.split('.')[0] + '_L.jpg'), image_array.astype(np.uint8))
+
+def test():
+ with torch.no_grad():
+ namelist = os.listdir(os.path.join(image_dir2, 'ir'))
+ for name in namelist:
+ ir_path = os.path.join(image_dir2, 'ir', name)
+ vis_path = os.path.join(image_dir2, 'vis', name)
+ if name.startswith('A'):
+ fuse_L(ir_path, vis_path, save_dir, name)
+ else:
+ fuse_RGB(ir_path, vis_path, save_dir, name)
+
+
+if __name__ == '__main__':
+ test()
\ No newline at end of file
diff --git a/testImage/ir/A_001.bmp b/testImage/ir/A_001.bmp
new file mode 100644
index 0000000..de3ed00
Binary files /dev/null and b/testImage/ir/A_001.bmp differ
diff --git a/testImage/ir/A_002.bmp b/testImage/ir/A_002.bmp
new file mode 100644
index 0000000..f6ba32e
Binary files /dev/null and b/testImage/ir/A_002.bmp differ
diff --git a/testImage/ir/A_003.bmp b/testImage/ir/A_003.bmp
new file mode 100644
index 0000000..0c3aec8
Binary files /dev/null and b/testImage/ir/A_003.bmp differ
diff --git a/testImage/ir/A_004.bmp b/testImage/ir/A_004.bmp
new file mode 100644
index 0000000..2f18244
Binary files /dev/null and b/testImage/ir/A_004.bmp differ
diff --git a/testImage/ir/A_005.bmp b/testImage/ir/A_005.bmp
new file mode 100644
index 0000000..1d635c9
Binary files /dev/null and b/testImage/ir/A_005.bmp differ
diff --git a/testImage/ir/A_006.bmp b/testImage/ir/A_006.bmp
new file mode 100644
index 0000000..a6f38ee
Binary files /dev/null and b/testImage/ir/A_006.bmp differ
diff --git a/testImage/ir/A_007.bmp b/testImage/ir/A_007.bmp
new file mode 100644
index 0000000..241b236
Binary files /dev/null and b/testImage/ir/A_007.bmp differ
diff --git a/testImage/ir/A_008.bmp b/testImage/ir/A_008.bmp
new file mode 100644
index 0000000..41e0734
Binary files /dev/null and b/testImage/ir/A_008.bmp differ
diff --git a/testImage/ir/A_009.bmp b/testImage/ir/A_009.bmp
new file mode 100644
index 0000000..a796e61
Binary files /dev/null and b/testImage/ir/A_009.bmp differ
diff --git a/testImage/ir/A_010.bmp b/testImage/ir/A_010.bmp
new file mode 100644
index 0000000..0917ff9
Binary files /dev/null and b/testImage/ir/A_010.bmp differ
diff --git a/testImage/ir/A_011.bmp b/testImage/ir/A_011.bmp
new file mode 100644
index 0000000..1498305
Binary files /dev/null and b/testImage/ir/A_011.bmp differ
diff --git a/testImage/ir/A_012.bmp b/testImage/ir/A_012.bmp
new file mode 100644
index 0000000..4997ebd
Binary files /dev/null and b/testImage/ir/A_012.bmp differ
diff --git a/testImage/ir/A_013.bmp b/testImage/ir/A_013.bmp
new file mode 100644
index 0000000..cd7ef9d
Binary files /dev/null and b/testImage/ir/A_013.bmp differ
diff --git a/testImage/ir/A_014.bmp b/testImage/ir/A_014.bmp
new file mode 100644
index 0000000..81f773b
Binary files /dev/null and b/testImage/ir/A_014.bmp differ
diff --git a/testImage/ir/A_015.bmp b/testImage/ir/A_015.bmp
new file mode 100644
index 0000000..e21d308
Binary files /dev/null and b/testImage/ir/A_015.bmp differ
diff --git a/testImage/ir/A_016.bmp b/testImage/ir/A_016.bmp
new file mode 100644
index 0000000..7304b70
Binary files /dev/null and b/testImage/ir/A_016.bmp differ
diff --git a/testImage/ir/A_017.bmp b/testImage/ir/A_017.bmp
new file mode 100644
index 0000000..eafff51
Binary files /dev/null and b/testImage/ir/A_017.bmp differ
diff --git a/testImage/ir/A_018.bmp b/testImage/ir/A_018.bmp
new file mode 100644
index 0000000..e89d51e
Binary files /dev/null and b/testImage/ir/A_018.bmp differ
diff --git a/testImage/ir/A_019.bmp b/testImage/ir/A_019.bmp
new file mode 100644
index 0000000..c032f4f
Binary files /dev/null and b/testImage/ir/A_019.bmp differ
diff --git a/testImage/ir/A_020.bmp b/testImage/ir/A_020.bmp
new file mode 100644
index 0000000..6002d27
Binary files /dev/null and b/testImage/ir/A_020.bmp differ
diff --git a/testImage/ir/A_021.bmp b/testImage/ir/A_021.bmp
new file mode 100644
index 0000000..d8025b4
Binary files /dev/null and b/testImage/ir/A_021.bmp differ
diff --git a/testImage/ir/A_022.bmp b/testImage/ir/A_022.bmp
new file mode 100644
index 0000000..e09f761
Binary files /dev/null and b/testImage/ir/A_022.bmp differ
diff --git a/testImage/ir/A_023.bmp b/testImage/ir/A_023.bmp
new file mode 100644
index 0000000..324e599
Binary files /dev/null and b/testImage/ir/A_023.bmp differ
diff --git a/testImage/ir/A_024.bmp b/testImage/ir/A_024.bmp
new file mode 100644
index 0000000..3a42a9c
Binary files /dev/null and b/testImage/ir/A_024.bmp differ
diff --git a/testImage/ir/A_025.bmp b/testImage/ir/A_025.bmp
new file mode 100644
index 0000000..ce3732b
Binary files /dev/null and b/testImage/ir/A_025.bmp differ
diff --git a/testImage/ir/A_026.bmp b/testImage/ir/A_026.bmp
new file mode 100644
index 0000000..ff84cc2
Binary files /dev/null and b/testImage/ir/A_026.bmp differ
diff --git a/testImage/ir/A_027.bmp b/testImage/ir/A_027.bmp
new file mode 100644
index 0000000..6ed7e61
Binary files /dev/null and b/testImage/ir/A_027.bmp differ
diff --git a/testImage/ir/A_028.bmp b/testImage/ir/A_028.bmp
new file mode 100644
index 0000000..1f24cbc
Binary files /dev/null and b/testImage/ir/A_028.bmp differ
diff --git a/testImage/ir/A_029.bmp b/testImage/ir/A_029.bmp
new file mode 100644
index 0000000..9d6426a
Binary files /dev/null and b/testImage/ir/A_029.bmp differ
diff --git a/testImage/ir/A_030.bmp b/testImage/ir/A_030.bmp
new file mode 100644
index 0000000..3c0d4ed
Binary files /dev/null and b/testImage/ir/A_030.bmp differ
diff --git a/testImage/ir/A_031.bmp b/testImage/ir/A_031.bmp
new file mode 100644
index 0000000..71eb720
Binary files /dev/null and b/testImage/ir/A_031.bmp differ
diff --git a/testImage/ir/A_032.bmp b/testImage/ir/A_032.bmp
new file mode 100644
index 0000000..324e599
Binary files /dev/null and b/testImage/ir/A_032.bmp differ
diff --git a/testImage/ir/A_033.bmp b/testImage/ir/A_033.bmp
new file mode 100644
index 0000000..8102197
Binary files /dev/null and b/testImage/ir/A_033.bmp differ
diff --git a/testImage/ir/A_034.bmp b/testImage/ir/A_034.bmp
new file mode 100644
index 0000000..7f05d83
Binary files /dev/null and b/testImage/ir/A_034.bmp differ
diff --git a/testImage/ir/A_035.bmp b/testImage/ir/A_035.bmp
new file mode 100644
index 0000000..de51fdc
Binary files /dev/null and b/testImage/ir/A_035.bmp differ
diff --git a/testImage/ir/A_036.bmp b/testImage/ir/A_036.bmp
new file mode 100644
index 0000000..4145922
Binary files /dev/null and b/testImage/ir/A_036.bmp differ
diff --git a/testImage/ir/A_037.bmp b/testImage/ir/A_037.bmp
new file mode 100644
index 0000000..ec620a8
Binary files /dev/null and b/testImage/ir/A_037.bmp differ
diff --git a/testImage/vis/A_001.bmp b/testImage/vis/A_001.bmp
new file mode 100644
index 0000000..2f825ce
Binary files /dev/null and b/testImage/vis/A_001.bmp differ
diff --git a/testImage/vis/A_002.bmp b/testImage/vis/A_002.bmp
new file mode 100644
index 0000000..948f345
Binary files /dev/null and b/testImage/vis/A_002.bmp differ
diff --git a/testImage/vis/A_003.bmp b/testImage/vis/A_003.bmp
new file mode 100644
index 0000000..8cb060e
Binary files /dev/null and b/testImage/vis/A_003.bmp differ
diff --git a/testImage/vis/A_004.bmp b/testImage/vis/A_004.bmp
new file mode 100644
index 0000000..69bca33
Binary files /dev/null and b/testImage/vis/A_004.bmp differ
diff --git a/testImage/vis/A_005.bmp b/testImage/vis/A_005.bmp
new file mode 100644
index 0000000..84ba5a1
Binary files /dev/null and b/testImage/vis/A_005.bmp differ
diff --git a/testImage/vis/A_006.bmp b/testImage/vis/A_006.bmp
new file mode 100644
index 0000000..0388bb4
Binary files /dev/null and b/testImage/vis/A_006.bmp differ
diff --git a/testImage/vis/A_007.bmp b/testImage/vis/A_007.bmp
new file mode 100644
index 0000000..872a582
Binary files /dev/null and b/testImage/vis/A_007.bmp differ
diff --git a/testImage/vis/A_008.bmp b/testImage/vis/A_008.bmp
new file mode 100644
index 0000000..57c94c2
Binary files /dev/null and b/testImage/vis/A_008.bmp differ
diff --git a/testImage/vis/A_009.bmp b/testImage/vis/A_009.bmp
new file mode 100644
index 0000000..0db825e
Binary files /dev/null and b/testImage/vis/A_009.bmp differ
diff --git a/testImage/vis/A_010.bmp b/testImage/vis/A_010.bmp
new file mode 100644
index 0000000..5e66a2e
Binary files /dev/null and b/testImage/vis/A_010.bmp differ
diff --git a/testImage/vis/A_011.bmp b/testImage/vis/A_011.bmp
new file mode 100644
index 0000000..b91b350
Binary files /dev/null and b/testImage/vis/A_011.bmp differ
diff --git a/testImage/vis/A_012.bmp b/testImage/vis/A_012.bmp
new file mode 100644
index 0000000..6ab16f7
Binary files /dev/null and b/testImage/vis/A_012.bmp differ
diff --git a/testImage/vis/A_013.bmp b/testImage/vis/A_013.bmp
new file mode 100644
index 0000000..e440e86
Binary files /dev/null and b/testImage/vis/A_013.bmp differ
diff --git a/testImage/vis/A_014.bmp b/testImage/vis/A_014.bmp
new file mode 100644
index 0000000..afc7195
Binary files /dev/null and b/testImage/vis/A_014.bmp differ
diff --git a/testImage/vis/A_015.bmp b/testImage/vis/A_015.bmp
new file mode 100644
index 0000000..75407b4
Binary files /dev/null and b/testImage/vis/A_015.bmp differ
diff --git a/testImage/vis/A_016.bmp b/testImage/vis/A_016.bmp
new file mode 100644
index 0000000..689c662
Binary files /dev/null and b/testImage/vis/A_016.bmp differ
diff --git a/testImage/vis/A_017.bmp b/testImage/vis/A_017.bmp
new file mode 100644
index 0000000..55bdb87
Binary files /dev/null and b/testImage/vis/A_017.bmp differ
diff --git a/testImage/vis/A_018.bmp b/testImage/vis/A_018.bmp
new file mode 100644
index 0000000..d4601c0
Binary files /dev/null and b/testImage/vis/A_018.bmp differ
diff --git a/testImage/vis/A_019.bmp b/testImage/vis/A_019.bmp
new file mode 100644
index 0000000..1cef777
Binary files /dev/null and b/testImage/vis/A_019.bmp differ
diff --git a/testImage/vis/A_020.bmp b/testImage/vis/A_020.bmp
new file mode 100644
index 0000000..ab916d4
Binary files /dev/null and b/testImage/vis/A_020.bmp differ
diff --git a/testImage/vis/A_021.bmp b/testImage/vis/A_021.bmp
new file mode 100644
index 0000000..3f97d79
Binary files /dev/null and b/testImage/vis/A_021.bmp differ
diff --git a/testImage/vis/A_022.bmp b/testImage/vis/A_022.bmp
new file mode 100644
index 0000000..9835c10
Binary files /dev/null and b/testImage/vis/A_022.bmp differ
diff --git a/testImage/vis/A_023.bmp b/testImage/vis/A_023.bmp
new file mode 100644
index 0000000..8f6d7e9
Binary files /dev/null and b/testImage/vis/A_023.bmp differ
diff --git a/testImage/vis/A_024.bmp b/testImage/vis/A_024.bmp
new file mode 100644
index 0000000..b9a0aa1
Binary files /dev/null and b/testImage/vis/A_024.bmp differ
diff --git a/testImage/vis/A_025.bmp b/testImage/vis/A_025.bmp
new file mode 100644
index 0000000..b4add2c
Binary files /dev/null and b/testImage/vis/A_025.bmp differ
diff --git a/testImage/vis/A_026.bmp b/testImage/vis/A_026.bmp
new file mode 100644
index 0000000..d66ec7b
Binary files /dev/null and b/testImage/vis/A_026.bmp differ
diff --git a/testImage/vis/A_027.bmp b/testImage/vis/A_027.bmp
new file mode 100644
index 0000000..70a3639
Binary files /dev/null and b/testImage/vis/A_027.bmp differ
diff --git a/testImage/vis/A_028.bmp b/testImage/vis/A_028.bmp
new file mode 100644
index 0000000..3c90f9d
Binary files /dev/null and b/testImage/vis/A_028.bmp differ
diff --git a/testImage/vis/A_029.bmp b/testImage/vis/A_029.bmp
new file mode 100644
index 0000000..585c877
Binary files /dev/null and b/testImage/vis/A_029.bmp differ
diff --git a/testImage/vis/A_030.bmp b/testImage/vis/A_030.bmp
new file mode 100644
index 0000000..796b7cc
Binary files /dev/null and b/testImage/vis/A_030.bmp differ
diff --git a/testImage/vis/A_031.bmp b/testImage/vis/A_031.bmp
new file mode 100644
index 0000000..229ff71
Binary files /dev/null and b/testImage/vis/A_031.bmp differ
diff --git a/testImage/vis/A_032.bmp b/testImage/vis/A_032.bmp
new file mode 100644
index 0000000..2e0d637
Binary files /dev/null and b/testImage/vis/A_032.bmp differ
diff --git a/testImage/vis/A_033.bmp b/testImage/vis/A_033.bmp
new file mode 100644
index 0000000..9929154
Binary files /dev/null and b/testImage/vis/A_033.bmp differ
diff --git a/testImage/vis/A_034.bmp b/testImage/vis/A_034.bmp
new file mode 100644
index 0000000..517bbae
Binary files /dev/null and b/testImage/vis/A_034.bmp differ
diff --git a/testImage/vis/A_035.bmp b/testImage/vis/A_035.bmp
new file mode 100644
index 0000000..9d66070
Binary files /dev/null and b/testImage/vis/A_035.bmp differ
diff --git a/testImage/vis/A_036.bmp b/testImage/vis/A_036.bmp
new file mode 100644
index 0000000..fa047c1
Binary files /dev/null and b/testImage/vis/A_036.bmp differ
diff --git a/testImage/vis/A_037.bmp b/testImage/vis/A_037.bmp
new file mode 100644
index 0000000..e4239d7
Binary files /dev/null and b/testImage/vis/A_037.bmp differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..b2ed896
--- /dev/null
+++ b/train.py
@@ -0,0 +1,167 @@
+import os
+import random
+import sys
+import time
+import glob
+import numpy as np
+import torch
+from PIL import Image
+
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import torch.utils
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
+from model import Encoder, Decoder
+import pytorch_msssim
+import torchvision.transforms as transforms
+import genotypes
+parser = argparse.ArgumentParser("untitled")
+
+parser.add_argument('--batch_size', type=int, default=4, help='batch size')
+parser.add_argument('--learning_rate', type=float, default=1e-4, help='init learning rate') #0.025-->2e-4
+# parser.add_argument('--learning_rate_min', type=float, default=1e-5, help='min learning rate') #0.001-->1e-4
+# parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+# parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
+# parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
+parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
+parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
+parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
+parser.add_argument('--layers', type=int, default=2, help='total number of layers')
+
+parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
+parser.add_argument('--save', type=str, default='EXP', help='experiment name')
+parser.add_argument('--seed', type=int, default=2, help='random seed')
+# parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
+parser.add_argument('--dataset1', type=str, default=r'C:\Users\ADMIN\Desktop\DATA\data128\crop_infrared', help='Infrared images for training')
+parser.add_argument('--dataset2', type=str, default=r'C:\Users\ADMIN\Desktop\DATA\data128\crop_visible', help='Visible images for training')
+
+args = parser.parse_args()
+args.save = 'train{}'.format(time.strftime("%Y%m%d-%H%M%S"))
+utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
+fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
+fh.setFormatter(logging.Formatter(log_format))
+logging.getLogger().addHandler(fh)
+
+
+def main():
+ if not torch.cuda.is_available():
+ logging.info('no gpu device available')
+ sys.exit(1)
+ np.random.seed(args.seed)
+ torch.cuda.set_device(args.gpu)
+    cudnn.benchmark = True  # let cuDNN pick the fastest kernels for fixed input sizes
+    torch.manual_seed(args.seed)  # seed the CPU RNG
+    cudnn.enabled = True  # enable cuDNN (may use non-deterministic algorithms)
+    torch.cuda.manual_seed(args.seed)  # seed the GPU RNG
+ logging.info('gpu device = %d' % args.gpu)
+ logging.info("args = %s", args)
+ mse_loss = torch.nn.MSELoss().cuda()
+ ssim_loss = pytorch_msssim.msssim
+
+    genotype_en1 = genotypes.genotype_en1
+    genotype_en2 = genotypes.genotype_en2
+
+    genotype2 = genotypes.genotype_de
+
+ encoder1 = Encoder(args.init_channels, args.layers, genotype_en1).cuda()
+ encoder2 = Encoder(args.init_channels, args.layers, genotype_en2).cuda()
+
+ decoder = Decoder(args.init_channels, args.layers, genotype2).cuda()
+
+
+ # logging.info("param size = %fMB", utils.count_parameters_in_MB(encoder1)*3)
+
+ para1 = [{'params': encoder1.parameters(), 'lr': args.learning_rate},
+ {'params': decoder.parameters(), 'lr': args.learning_rate}]
+ para2 = [{'params': encoder2.parameters(), 'lr': args.learning_rate},
+ {'params': decoder.parameters(), 'lr': args.learning_rate}]
+ optimizer1 = torch.optim.Adam(para1, args.learning_rate)
+ optimizer2 = torch.optim.Adam(para2, args.learning_rate)
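+    # The decoder's parameters appear in both optimizers, so the shared decoder
+    # is updated alternately from the infrared and the visible branch.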
+
+ epochs = args.epochs
+ Infrared_path_list = utils.list_images(args.dataset1)
+ Visible_path_list = utils.list_images(args.dataset2)
+ random.shuffle(Infrared_path_list)
+ random.shuffle(Visible_path_list)
+ train_num = 15000
+
+ Infrared_path_list = Infrared_path_list[:train_num]
+ Visible_path_list = Visible_path_list[:train_num]
+    train_queue1, batches = utils.load_dataset(Infrared_path_list, args.batch_size)  # infrared train paths
+    train_queue2, batches = utils.load_dataset(Visible_path_list, args.batch_size)  # visible train paths
+ train_queue12 = [train_queue1, train_queue2]
+ optimizer12 = [optimizer1, optimizer2]
+ encoder12 = [encoder1, encoder2]
+ print("len of(infrared_train_queue):", len(train_queue1)*2)
+
+ # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer3, gamma=0.9)
+ for epoch in range(epochs):
+ # lr = scheduler.get_last_lr()
+ # logging.info('epoch %d lr %e', epoch, lr[0])
+
+ # training
+
+ train(train_queue12, batches, args, encoder12, decoder, mse_loss, ssim_loss, optimizer12, epoch)
+
+ if (epoch+1)%5==0:
+ utils.save(encoder1, os.path.join(args.save, 'encoder1_epoch'+str(epoch+1)+'.pt'))
+ utils.save(encoder2, os.path.join(args.save, 'encoder2_epoch'+str(epoch+1)+'.pt'))
+ utils.save(decoder, os.path.join(args.save, 'decoder_epoch'+str(epoch+1)+'.pt'))
+
+
+ # scheduler.step()
+
+
+tensor_to_pil = transforms.ToPILImage()
+
+
+def train(train_queue_IV, batches, args, encoder12, decoder, mse_loss, ssim_loss, optimizer12, epoch):
+ encoder12[0].train()
+ encoder12[1].train()
+ decoder.train()
+ for batch in range(batches):
+ for i, train_queue, encoder, optimizer in zip(range(2), train_queue_IV, encoder12, optimizer12):
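+            # i == 0 is the infrared branch, i == 1 the visible branch; the
+            # shared decoder receives gradient updates from both.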
+
+            image_paths_train = train_queue[batch * args.batch_size:(batch * args.batch_size + args.batch_size)]  # paths for one training batch
+
+ inputs = utils.get_train_images_auto(image_paths_train).cuda()
+
+ en1, en2 = encoder(inputs)
+ outputs = decoder(en1, en2)
+
+
+ optimizer.zero_grad()
+
+ ssim_loss_value = 0.
+ pixel_loss_value = 0.
+ for output, input in zip(outputs, inputs):
+ output, input = torch.unsqueeze(output, 0), torch.unsqueeze(input, 0)
+ pixel_loss_temp = mse_loss(input, output)
+ ssim_loss_temp = ssim_loss(input, output, normalize=True, val_range=255)
+ ssim_loss_value += (1 - ssim_loss_temp)
+ pixel_loss_value += pixel_loss_temp
+ ssim_loss_value /= len(outputs)
+ pixel_loss_value /= len(outputs)
+
+            total_loss = pixel_loss_value + 100*ssim_loss_value  # weighted objective: MSE + 100 * (1 - MS-SSIM)
+ # total_loss = torch.tensor(total_loss, dtype=torch.float)
+ total_loss.backward()
+ # nn.utils.clip_grad_norm_(model_former.parameters(), args.grad_clip)
+ # nn.utils.clip_grad_value_(model_former.parameters(), args.grad_clip)
+ # nn.utils.clip_grad_value_(model_latter.parameters(), args.grad_clip)
+ optimizer.step()
+ if i==0:
+ logging.info("Infrared_epoch: %d batch: %d loss: %f", epoch, batch, total_loss)
+ else:
+ logging.info("Visible_epoch: %d batch: %d loss: %f", epoch, batch, total_loss)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/trainJT/decoder_epoch50.pt b/trainJT/decoder_epoch50.pt
new file mode 100644
index 0000000..1f3b71c
Binary files /dev/null and b/trainJT/decoder_epoch50.pt differ
diff --git a/trainJT/encoder1_epoch50.pt b/trainJT/encoder1_epoch50.pt
new file mode 100644
index 0000000..686548b
Binary files /dev/null and b/trainJT/encoder1_epoch50.pt differ
diff --git a/trainJT/encoder2_epoch50.pt b/trainJT/encoder2_epoch50.pt
new file mode 100644
index 0000000..9144fa8
Binary files /dev/null and b/trainJT/encoder2_epoch50.pt differ
diff --git a/train_search.py b/train_search.py
new file mode 100644
index 0000000..cd39d66
--- /dev/null
+++ b/train_search.py
@@ -0,0 +1,205 @@
+import os
+import random
+import sys
+import time
+import glob
+import numpy as np
+import torch
+from PIL import Image
+
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import torch.utils
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from model_search import Encoder, Decoder
+from architect import Architect
+import pytorch_msssim
+import torchvision.transforms as transforms
+
+parser = argparse.ArgumentParser("untitled")
+
+parser.add_argument('--batch_size', type=int, default=4, help='batch size')  # reduced from 64 to 4
+parser.add_argument('--learning_rate', type=float, default=1e-5, help='init learning rate')
+parser.add_argument('--learning_rate_min', type=float, default=1e-6, help='min learning rate')
+parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
+parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
+parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
+parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
+parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
+parser.add_argument('--layers', type=int, default=2, help='total number of layers')
+parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
+parser.add_argument('--save', type=str, default='EXP', help='experiment name')
+parser.add_argument('--seed', type=int, default=2, help='random seed')
+parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
+parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
+parser.add_argument('--arch_learning_rate', type=float, default=4e-4, help='learning rate for arch encoding') # 3e-4
+parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') # 1e-3
+parser.add_argument('--dataset1', type=str, default=r'C:\Users\ADMIN\Desktop\search_U2F\data_vis_ir64_160\crop_infrared160', help='Infrared images for training')
+parser.add_argument('--dataset2', type=str, default=r'C:\Users\ADMIN\Desktop\search_U2F\data_vis_ir64_160\crop_visible160', help='Visible images for training')
+args = parser.parse_args()
+args.save = 'search-C{}-B{}-{}'.format(args.init_channels, args.batch_size, time.strftime("%Y%m%d-%H%M%S"))
+utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
+fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
+fh.setFormatter(logging.Formatter(log_format))
+logging.getLogger().addHandler(fh)
+os.mkdir(args.save+'/output')
+
+
+def main():
+ if not torch.cuda.is_available():
+ logging.info('no gpu device available')
+ sys.exit(1)
+ np.random.seed(args.seed)
+ torch.cuda.set_device(args.gpu)
+    cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms
+    torch.manual_seed(args.seed)  # seed the CPU RNG
+    cudnn.enabled = True  # enable cuDNN (may use nondeterministic algorithms for speed)
+    torch.cuda.manual_seed(args.seed)  # seed the GPU RNG
+ logging.info('gpu device = %d' % args.gpu)
+ logging.info("args = %s", args)
+ mse_loss = torch.nn.MSELoss().cuda()
+ ssim_loss = pytorch_msssim.msssim
+
+ encoder1 = Encoder(args.init_channels, args.layers).cuda()
+ encoder2 = Encoder(args.init_channels, args.layers).cuda()
+
+ decoder = Decoder(args.init_channels, args.layers).cuda()
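+    # In the search phase, Encoder and Decoder come from model_search and carry
+    # learnable architecture weights (alphas) alongside the network weights.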
+
+ # logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
+ para1 = [{'params': encoder1.parameters(), 'lr': args.learning_rate},
+ {'params': decoder.parameters(), 'lr': args.learning_rate}]
+ para2 = [{'params': encoder2.parameters(), 'lr': args.learning_rate},
+ {'params': decoder.parameters(), 'lr': args.learning_rate}]
+ optimizer1 = torch.optim.SGD(
+ para1,
+ args.learning_rate,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+ optimizer2 = torch.optim.SGD(
+ para2,
+ args.learning_rate,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ epochs = args.epochs
+    # load the training datasets
+ Infrared_path_list = utils.list_images(args.dataset1)
+ Visible_path_list = utils.list_images(args.dataset2)
+    random.shuffle(Infrared_path_list)
+    random.shuffle(Visible_path_list)
+    train_num = 15000
+
+ Infrared_path_list = Infrared_path_list[:train_num]
+ Visible_path_list = Visible_path_list[:train_num]
+
+ train_queue1, batches = utils.load_dataset(Infrared_path_list[:train_num//2], args.batch_size)
+ valid_queue1, batches = utils.load_dataset(Infrared_path_list[train_num//2:train_num], args.batch_size)
+
+ train_queue2, batches = utils.load_dataset(Visible_path_list[:train_num//2], args.batch_size)
+ valid_queue2, batches = utils.load_dataset(Visible_path_list[train_num//2:train_num], args.batch_size)
+
+
+ scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer1, args.epochs, eta_min=args.learning_rate_min)
+ scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer2, args.epochs, eta_min=args.learning_rate_min)
+ architect1 = Architect(encoder1, decoder, args, mse_loss, ssim_loss)
+ architect2 = Architect(encoder2, decoder, args, mse_loss, ssim_loss)
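+    # One Architect per modality: each performs the DARTS-style bi-level update
+    # of the architecture weights on held-out validation batches.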
+
+ train_queue = [train_queue1, train_queue2]
+ valid_queue = [valid_queue1, valid_queue2]
+ encoder = [encoder1, encoder2]
+ optimizer = [optimizer1, optimizer2]
+ architect = [architect1, architect2]
+ for epoch in range(epochs):
+ # lr = scheduler.get_lr()[0]
+ lr1 = scheduler1.get_last_lr()
+ lr2 = scheduler2.get_last_lr()
+
+ logging.info('epoch %d lr1 %e lr2 %e', epoch, lr1[0], lr2[0])
+
+ genotype_en1 = encoder1.genotype()
+ genotype_en2 = encoder2.genotype()
+ genotype_de = decoder.genotype()
+ logging.info('genotype_en1 = %s', genotype_en1)
+ logging.info('genotype_en2 = %s', genotype_en2)
+ logging.info('genotype_de = %s', genotype_de)
+
+ print(F.softmax(encoder1.alphas, dim=-1))
+ print(F.softmax(encoder2.alphas, dim=-1))
+ print(F.softmax(decoder.alphas, dim=-1))
+
+ logging.info('en1 = %s', F.softmax(encoder1.alphas, dim=-1))
+ logging.info('en2 = %s', F.softmax(encoder2.alphas, dim=-1))
+ logging.info('de = %s', F.softmax(decoder.alphas, dim=-1))
+
+ # training
+
+ train(train_queue, valid_queue, batches, encoder, decoder, architect, mse_loss, ssim_loss, optimizer, [lr1, lr2], epoch)
+ if (epoch+1) % 5 == 0:
+ utils.save(encoder1, os.path.join(args.save, 'encoder1_epoch'+str(epoch+1)+'.pt'))
+ utils.save(encoder2, os.path.join(args.save, 'encoder2_epoch'+str(epoch+1)+'.pt'))
+ utils.save(decoder, os.path.join(args.save, 'decoder_epoch'+str(epoch+1)+'.pt'))
+
+ scheduler1.step()
+ scheduler2.step()
+
+
+tensor_to_pil = transforms.ToPILImage()
+
+
+def train(train_queue12, valid_queue12, batches, encoder12, decoder, architect12, mse_loss, ssim, optimizer12, lr12, epoch):
+ encoder12[0].train()
+ encoder12[1].train()
+ decoder.train()
+ for batch in range(batches):
+ for i, train_queue, valid_queue, architect, optimizer, encoder, lr in zip(range(2), train_queue12, valid_queue12, architect12, optimizer12, encoder12, lr12):
+            image_paths_train = train_queue[batch * args.batch_size:(batch * args.batch_size + args.batch_size)]  # paths for one training batch
+            inputs = utils.get_train_images_auto(image_paths_train).cuda()  # load the batch as a tensor
+
+            image_paths_valid = valid_queue[batch * args.batch_size:(batch * args.batch_size + args.batch_size)]  # paths for one validation batch
+            inputs_search = utils.get_train_images_auto(image_paths_valid).cuda()  # load the batch as a tensor
+
+ architect.step(inputs, inputs_search, lr, optimizer, unrolled=args.unrolled)
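+            # architect.step adjusts the architecture weights on the held-out
+            # batch; the network-weight update below uses the training batch.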
+
+ print(F.softmax(encoder.alphas, dim=-1))
+ print(F.softmax(decoder.alphas, dim=-1))
+
+ en1, en2 = encoder(inputs)
+ outputs = decoder(en1, en2)
+
+
+ optimizer.zero_grad()
+
+ ssim_loss_value = 0.
+ pixel_loss_value = 0.
+ for output, input in zip(outputs, inputs):
+ output, input = torch.unsqueeze(output, 0), torch.unsqueeze(input, 0)
+ pixel_loss_temp = mse_loss(input, output)
+ ssim_loss_temp = ssim(input, output, normalize=True, val_range=255)
+ ssim_loss_value += (1 - ssim_loss_temp)
+ pixel_loss_value += pixel_loss_temp
+ ssim_loss_value /= len(outputs)
+ pixel_loss_value /= len(outputs)
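+            # same weighted objective as in train.py: MSE + 100 * (1 - MS-SSIM)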
+ total_loss = pixel_loss_value + 100*ssim_loss_value
+ total_loss.backward()
+ nn.utils.clip_grad_value_(encoder.parameters(), args.grad_clip)
+ nn.utils.clip_grad_value_(decoder.parameters(), args.grad_clip)
+ optimizer.step()
+ if i==0:
+ logging.info("Infrared_epoch: %d batch: %d loss: %f", epoch, batch+1, total_loss)
+ else:
+ logging.info("Visible_epoch: %d batch: %d loss: %f", epoch, batch+1, total_loss)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..a3e44ab
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,272 @@
+import os
+import random
+from os import listdir
+from os.path import join
+
+import numpy as np
+import torch
+import shutil
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+from PIL import Image
+
+from imageio import imread, imsave
+
+
+class AvgrageMeter(object):
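+    """Running-average meter (spelling kept as in the upstream DARTS utilities)."""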
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.avg = 0
+ self.sum = 0
+ self.cnt = 0
+
+ def update(self, val, n=1):
+ self.sum += val * n
+ self.cnt += n
+ self.avg = self.sum / self.cnt
+
+
+def accuracy(output, target, topk=(1,)):
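+    # Top-k classification accuracy (DARTS leftover; unused in fusion training).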
+ maxk = max(topk)
+ batch_size = target.size(0)
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
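+# Cutout augmentation: zeroes out a random square patch of the input image
+# (CIFAR training utility carried over from DARTS; unused in this pipeline).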
+class Cutout(object):
+ def __init__(self, length):
+ self.length = length
+
+ def __call__(self, img):
+ h, w = img.size(1), img.size(2)
+ mask = np.ones((h, w), np.float32)
+ y = np.random.randint(h)
+ x = np.random.randint(w)
+ y1 = np.clip(y - self.length // 2, 0, h)
+ y2 = np.clip(y + self.length // 2, 0, h)
+ x1 = np.clip(x - self.length // 2, 0, w)
+ x2 = np.clip(x + self.length // 2, 0, w)
+ mask[y1: y2, x1: x2] = 0.
+ mask = torch.from_numpy(mask)
+ mask = mask.expand_as(img)
+ img *= mask
+ return img
+
+
+def _data_transforms_cifar10(args):
+ CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
+ CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
+ train_transform = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
+ ])
+ if args.cutout:
+ train_transform.transforms.append(Cutout(args.cutout_length))
+ valid_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
+ ])
+
+ return train_transform, valid_transform
+
+
+# load training images
+def load_dataset(image_path, BATCH_SIZE, num_imgs=None):
+ if num_imgs is None:
+ num_imgs = len(image_path)
+ original_imgs_path = image_path[:num_imgs]
+ # random
+ # random.shuffle(original_imgs_path)
+ mod = num_imgs % BATCH_SIZE
+ # print('BATCH SIZE %d.' % BATCH_SIZE)
+ # print('Train images number %d.' % num_imgs)
+ # print('Train images samples %s.' % str(num_imgs / BATCH_SIZE))
+
+ if mod > 0:
+        print('Train set trimmed by %d samples to fit the batch size.\n' % mod)
+        original_imgs_path = original_imgs_path[:-mod]  # drop the remainder that does not fill a full batch
+ batches = int(len(original_imgs_path) // BATCH_SIZE)
+ return original_imgs_path, batches
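+
+# Usage sketch (hypothetical directory name): callers slice the returned path
+# list per batch, e.g.
+#   paths, batches = load_dataset(list_images('data/crop_infrared'), 4)
+#   first_batch_paths = paths[0:4]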
+
+def get_image(path, height=256, width=256, mode='L'):
+    if mode == 'L':
+        image = imread(path, pilmode=mode)
+    elif mode == 'RGB':
+        image = Image.open(path).convert('RGB')
+
+    if height is not None and width is not None:
+        image = np.array(Image.fromarray(np.asarray(image)).resize((width, height)))  # PIL expects (width, height)
+    return image
+
+
+'''
+def get_train_images_auto(paths, height=256, width=256, mode='RGB'):
+    if isinstance(paths, str):
+        paths = [paths]  # wrap a single path in a list
+    images = []
+    for path in paths:
+        image = get_image(path, height, width, mode=mode)
+        if mode == 'L':  # grayscale
+            image = np.reshape(image, [1, image.shape[0], image.shape[1]])
+        else:  # RGB: channels first
+            image = np.reshape(image, [image.shape[2], image.shape[0], image.shape[1]])
+        images.append(image)
+
+    images = np.stack(images, axis=0)  # add a batch dimension
+    images = torch.from_numpy(images).float()
+    return images
+'''
+
+pil_to_tensor = transforms.ToTensor()
+
+def get_train_images_auto(paths):
+    if isinstance(paths, str):
+        paths = [paths]  # wrap a single path in a list
+    images = []
+    for path in paths:
+        image = Image.open(path)
+        if image.mode == 'RGB':
+            image = image.convert('L')  # training operates on single-channel images
+        image = np.reshape(image, [1, image.size[1], image.size[0]])  # (C=1, H, W)
+        images.append(image)
+    images = np.stack(images, axis=0)  # add a batch dimension
+    images = torch.from_numpy(images).float()
+    return images
+
+
+def get_test_images(paths, height=None, width=None, mode='L'):
+ ImageToTensor = transforms.Compose([transforms.ToTensor()])
+ if isinstance(paths, str):
+ paths = [paths]
+ images = []
+    for path in paths:  # typically a single image at test time
+ image = get_image(path, height, width, mode=mode)
+ if mode == 'L':
+ image = np.reshape(image, [1, image.shape[0], image.shape[1]])
+ else:
+            image = ImageToTensor(image).float().numpy()*255
+ images.append(image)
+ images = np.stack(images, axis=0)
+ images = torch.from_numpy(images).float()
+ return images
+
+
+'''
+loader = transforms.Compose([
+    transforms.ToTensor()])
+
+
+def get_train_images_auto(paths):
+    if isinstance(paths, str):
+        paths = [paths]  # wrap a single path in a list
+    images = []
+    for path in paths:
+        image = Image.open(path)
+        if image.mode == 'RGB':
+            image = image.convert('L')  # convert to grayscale
+        img = loader(image)
+        images.append(img)
+    images = torch.stack(images)  # add a batch dimension
+    return images
+'''
+
+
+def list_images(directory):  # collect the paths of all images in a directory
+ images = []
+ names = []
+ dir = listdir(directory)
+ dir.sort()
+ for file in dir:
+ name = file.lower()
+ if name.endswith('.png'):
+ images.append(join(directory, file))
+ elif name.endswith('.jpg'):
+ images.append(join(directory, file))
+ elif name.endswith('.jpeg'):
+ images.append(join(directory, file))
+            name1 = name.split('.')
+            names.append(name1[0])  # note: names is built but never returned
+ return images
+
+
+def count_parameters_in_MB(model):
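+    # note: passing a generator to np.sum is deprecated in newer NumPy releases;
+    # the built-in sum() would be the safer choice here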
+ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
+
+
+def save_checkpoint(state, is_best, save):
+ filename = os.path.join(save, 'checkpoint.pth.tar')
+
+ torch.save(state, filename)
+
+ if is_best:
+ best_filename = os.path.join(save, 'model_best.pth.tar')
+
+ shutil.copyfile(filename, best_filename)
+
+
+def save(model, model_path):
+ torch.save(model.state_dict(), model_path)
+
+
+def load(model, model_path):
+ model.load_state_dict(torch.load(model_path))
+
+
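+# DARTS drop-path regularization: zeroes whole samples with probability
+# drop_prob and rescales the survivors by 1/keep_prob to keep the expectation.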
+def drop_path(x, drop_prob):
+    if drop_prob > 0.:
+        keep_prob = 1. - drop_prob
+        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
+        x.div_(keep_prob)
+        x.mul_(mask)
+    return x
+
+
+def create_exp_dir(path, scripts_to_save=None):
+ if not os.path.exists(path):
+ os.mkdir(path)
+ print('Experiment dir : {}'.format(path))
+ if scripts_to_save is not None:
+ os.mkdir(os.path.join(path, 'scripts'))
+ for script in scripts_to_save:
+ dst_file = os.path.join(path, 'scripts', os.path.basename(script))
+ shutil.copyfile(script, dst_file)