# Introduction

This is the implementation of the paper [SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion](https://ieeexplore.ieee.org/abstract/document/9528046).

## Requirements

* python >= 3.6
* pytorch == 1.7
* torchvision == 0.8

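If it helps, here is a quick sanity check of the environment (a minimal sketch based on the versions listed above; it is not part of the repository):

```python
# Sketch: confirm the interpreter and library versions roughly match the
# requirements above. Nothing here belongs to the SMoA code itself.
import sys

import torch
import torchvision

assert sys.version_info >= (3, 6), "python >= 3.6 is required"
print("pytorch:", torch.__version__)            # expected to start with 1.7
print("torchvision:", torchvision.__version__)  # expected to start with 0.8
print("CUDA available:", torch.cuda.is_available())  # the search code calls .cuda()
```
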
## Datasets

You can download the datasets [here](https://pan.baidu.com/s/1kUja4iau37MwLnGI8_lMWg?pwd=eapv).

## Test

```shell
python test.py
```

## Train from scratch

### step 1

```shell
python train_search.py
```

### step 2

Find the string describing the searched architectures in the log file, then copy and paste it into `genotypes.py`. Its format should be consistent with the existing architecture strings.

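For reference, the entries already present in `genotypes.py` (reproduced at the end of this page) have the following form; a string copied from the search log should follow the same pattern:

```python
# Genotype is the namedtuple already defined at the top of genotypes.py:
from collections import namedtuple
Genotype = namedtuple('Genotype', 'cell cell_concat')

# Example entry (this is the en1 encoder cell shipped with the repository);
# a newly searched string copied from the log replaces entries like this one.
genotype_en1 = Genotype(
    cell=[('sep_conv_3x3', 0), ('Spatialattention', 1),
          ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
          ('sep_conv_5x5', 0), ('sep_conv_5x5', 3),
          ('Denseblocks', 4), ('Denseblocks', 3)],
    cell_concat=range(2, 6))
```
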
### step 3

```shell
python train.py
```

## Citation

If you use any part of this code in your research, please cite our [paper](https://ieeexplore.ieee.org/abstract/document/9528046):

```
@ARTICLE{9528046,
  author={Liu, Jinyuan and Wu, Yuhui and Huang, Zhanbo and Liu, Risheng and Fan, Xin},
  journal={IEEE Signal Processing Letters},
  title={SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion},
  year={2021},
  volume={28},
  number={},
  pages={1818-1822},
  doi={10.1109/LSP.2021.3109818}}
```
# Architecture optimizer used during the search phase (step 1 above).
# This follows the DARTS "Architect" design, adapted to jointly optimize the
# architecture parameters of two networks: model_former (which maps the input
# to two feature maps en1, en2) and model_latter (which takes en1, en2 and
# produces the reconstructed output).
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable


def _concat(xs):
    # Flatten a sequence of tensors into a single 1-D tensor.
    return torch.cat([x.view(-1) for x in xs])


class Architect(object):

    def __init__(self, model_former, model_latter, args, mse_loss, ssim_loss):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model_former = model_former
        self.model_latter = model_latter
        self.mse_loss = mse_loss
        self.ssim_loss = ssim_loss
        # A single Adam optimizer updates the architecture parameters (alphas)
        # of both networks.
        para = [{'params': model_former.arch_parameters(), 'lr': args.arch_learning_rate},
                {'params': model_latter.arch_parameters(), 'lr': args.arch_learning_rate}]
        self.optimizer = torch.optim.Adam(para,
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

    # NOTE: the unrolled (second-order) methods below still reference a single
    # `self.model`, as in the original DARTS implementation, and have not been
    # adapted to the two-network setup; they would fail if step() were called
    # with unrolled=True, so the first-order path (_backward_step) is the one
    # that is actually usable here.
    def _compute_unrolled_model(self, input, eta, network_optimizer):
        # Build a copy of the model whose weights have taken one virtual SGD step.
        loss = self.model._loss(input)
        theta = _concat(self.model.parameters()).data
        try:
            moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(
                self.network_momentum)
        except KeyError:
            # No momentum buffers yet (first iteration).
            moment = torch.zeros_like(theta)
        dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay * theta
        unrolled_model = self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta[0]))
        return unrolled_model

    def step(self, input_train, input_valid, eta, network_optimizer, unrolled):
        # One update of the architecture parameters on a validation batch.
        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(input_train, input_valid, eta, network_optimizer)
        else:
            self._backward_step(input_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid):
        # First-order approximation: back-propagate the validation loss
        # directly into the architecture parameters.
        en1, en2 = self.model_former(input_valid)   # former network: two feature maps
        output_valid = self.model_latter(en1, en2)  # latter network: reconstructed images
        ssim_loss_value = 0.
        pixel_loss_value = 0.
        for output, input in zip(output_valid, input_valid):
            output, input = torch.unsqueeze(output, 0), torch.unsqueeze(input, 0)
            pixel_loss_temp = self.mse_loss(input, output)
            ssim_loss_temp = self.ssim_loss(input, output, normalize=True, val_range=255)
            ssim_loss_value += (1 - ssim_loss_temp)
            pixel_loss_value += pixel_loss_temp
        ssim_loss_value /= len(output_valid)
        pixel_loss_value /= len(output_valid)

        # Weighted sum of the pixel and (1 - SSIM) losses.
        # (Original comment: "加权?" -- "weighting?")
        total_loss = pixel_loss_value + 100 * ssim_loss_value
        total_loss.backward()

    def _backward_step_unrolled(self, input_train, input_valid, eta, network_optimizer):
        # Second-order DARTS update: differentiate the validation loss of the
        # unrolled model and subtract the implicit-gradient correction term.
        unrolled_model = self._compute_unrolled_model(input_train, eta, network_optimizer)
        unrolled_loss = unrolled_model._loss(input_valid)
        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, input_train)
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta[0])
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    def _construct_model_from_theta(self, theta):
        # Create a fresh model instance and load the flattened weight vector
        # `theta` back into it, parameter by parameter.
        model_new = self.model.new()
        model_dict = self.model.state_dict()
        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset: offset + v_length].view(v.size())
            offset += v_length
        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

    def _hessian_vector_product(self, vector, input, r=1e-2):
        # Finite-difference approximation of the Hessian-vector product:
        # (grad_alpha L(w + R*v) - grad_alpha L(w - R*v)) / (2R),
        # computed by perturbing the weights in place and then restoring them.
        R = r / _concat(vector).norm()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss = self.model._loss(input)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)
        loss = self.model._loss(input)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
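
A rough sketch of how this class is likely driven inside `train_search.py` (not shown on this page); only the `Architect` constructor and `step()` signatures come from the class above, while the loaders, `args`, optimizer, and scheduler names are assumptions:

```python
# Hypothetical fragment of the search loop; everything except the Architect
# API is assumed for illustration.
architect = Architect(model_former, model_latter, args, mse_loss, ssim_loss)

for input_train, input_valid in zip(train_queue, valid_queue):
    input_train, input_valid = input_train.cuda(), input_valid.cuda()

    # Update the architecture parameters (alphas) on the validation batch.
    # eta is the current weight learning rate (read as eta[0] internally);
    # unrolled=False selects the first-order path, the one wired up for the
    # two-network setup.
    architect.step(input_train, input_valid,
                   eta=scheduler.get_last_lr(),
                   network_optimizer=weight_optimizer,
                   unrolled=False)

    # ...followed by the usual weight update of model_former / model_latter
    # with weight_optimizer on input_train...
```
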
# genotypes.py: search-space primitives and the searched architectures.
from collections import namedtuple

# A genotype lists, for each intermediate node, pairs of (operation, input index),
# plus the node indices whose outputs are concatenated into the cell output.
Genotype = namedtuple('Genotype', 'cell cell_concat')

# Candidate operations in the search space.
PRIMITIVES = [
    'none',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'Spatialattention',
    'Denseblocks',
    'Residualblocks'
]

# Architectures obtained after 20 epochs of search.
genotype_en1 = Genotype(cell=[('sep_conv_3x3', 0), ('Spatialattention', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_5x5', 3), ('Denseblocks', 4), ('Denseblocks', 3)], cell_concat=range(2, 6))
genotype_en2 = Genotype(cell=[('Spatialattention', 1), ('dil_conv_3x3', 0), ('dil_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('Spatialattention', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], cell_concat=range(2, 6))
genotype_de = Genotype(cell=[('sep_conv_3x3', 0), ('Denseblocks', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('Denseblocks', 4), ('sep_conv_3x3', 0)], cell_concat=range(2, 6))
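
As a quick illustration of how these entries are read, here is a small sketch (not code from the repository) that assumes the usual DARTS convention: consecutive pairs in `cell` are the two inputs of each intermediate node, with nodes 0 and 1 being the cell inputs.

```python
# Sketch: print a searched cell in readable form, assuming the DARTS
# convention described above.
from genotypes import genotype_en1

for node, i in enumerate(range(0, len(genotype_en1.cell), 2)):
    (op_a, in_a), (op_b, in_b) = genotype_en1.cell[i], genotype_en1.cell[i + 1]
    print(f"node {node + 2}: {op_a}(node {in_a}) + {op_b}(node {in_b})")
print("concatenate nodes:", list(genotype_en1.cell_concat))
```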