Skip to content

Commit

Permalink
upload code
Browse files Browse the repository at this point in the history
  • Loading branch information
JinyuanLiu-CV committed Jan 23, 2022
0 parents commit 25c8191
Show file tree
Hide file tree
Showing 137 changed files with 1,880 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
8 changes: 8 additions & 0 deletions .idea/code.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

161 changes: 161 additions & 0 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Introduction

This is the implementation of the paper [SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion](https://ieeexplore.ieee.org/abstract/document/9528046).

## Requirements

* python >= 3.6
* pytorch == 1.7
* torchvision == 0.8

## Datasets

You can download the datasets [here](https://pan.baidu.com/s/1kUja4iau37MwLnGI8_lMWg?pwd=eapv).

## Test

```shell
python test.py
```

## Train from scratch

### step 1

```shell
python train_search.py
```

### step 2

Find the string which descripting the searched architectures in the log file. Copy and paste it into the genotypes.py, the format should consist with the primary architecture string.

### step 3

```shell
python train.py
```

## Citation

If you use any part of this code in your research, please cite our [paper](https://ieeexplore.ieee.org/abstract/document/9528046):

```
@ARTICLE{9528046,
author={Liu, Jinyuan and Wu, Yuhui and Huang, Zhanbo and Liu, Risheng and Fan, Xin},
journal={IEEE Signal Processing Letters},
title={SMoA: Searching a Modality-Oriented Architecture for Infrared and Visible Image Fusion},
year={2021},
volume={28},
number={},
pages={1818-1822},
doi={10.1109/LSP.2021.3109818}}
```
Binary file added Strategy_vsm/A_001.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_002.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_003.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_004.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_005.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_006.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_007.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_008.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_009.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_010.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_011.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_012.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_013.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_014.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_015.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_016.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_017.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_018.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_019.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_020.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_021.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_022.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_023.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_024.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_025.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Strategy_vsm/A_026.jpg
Binary file added Strategy_vsm/A_027.jpg
Binary file added Strategy_vsm/A_028.jpg
Binary file added Strategy_vsm/A_029.jpg
Binary file added Strategy_vsm/A_030.jpg
Binary file added Strategy_vsm/A_031.jpg
Binary file added Strategy_vsm/A_032.jpg
Binary file added Strategy_vsm/A_033.jpg
Binary file added Strategy_vsm/A_034.jpg
Binary file added Strategy_vsm/A_035.jpg
Binary file added Strategy_vsm/A_036.jpg
Binary file added Strategy_vsm/A_037.jpg
Binary file added __pycache__/genotypes.cpython-36.pyc
Binary file not shown.
Binary file added __pycache__/model.cpython-36.pyc
Binary file not shown.
Binary file added __pycache__/operations.cpython-36.pyc
Binary file not shown.
Binary file added __pycache__/utils.cpython-36.pyc
Binary file not shown.
103 changes: 103 additions & 0 deletions architect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable


def _concat(xs):
return torch.cat([x.view(-1) for x in xs])


class Architect(object):

def __init__(self, model_former, model_latter, args, mse_loss, ssim_loss):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model_former = model_former
self.model_latter = model_latter
self.mse_loss = mse_loss
self.ssim_loss = ssim_loss
para = [{'params': model_former.arch_parameters(), 'lr': args.arch_learning_rate},
{'params': model_latter.arch_parameters(), 'lr': args.arch_learning_rate}]
self.optimizer = torch.optim.Adam(para,
lr=args.arch_learning_rate, betas=(0.5, 0.999),
weight_decay=args.arch_weight_decay)

def _compute_unrolled_model(self, input, eta, network_optimizer):
loss = self.model._loss(input)
theta = _concat(self.model.parameters()).data
try:
moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(
self.network_momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay * theta
unrolled_model = self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta[0]))
return unrolled_model

def step(self, input_train, input_valid, eta, network_optimizer, unrolled):
self.optimizer.zero_grad()
if unrolled:
self._backward_step_unrolled(input_train, input_valid, eta, network_optimizer)
else:
self._backward_step(input_valid)
self.optimizer.step()

def _backward_step(self, input_valid):
en1, en2 = self.model_former(input_valid) # ######
output_valid = self.model_latter(en1, en2)
ssim_loss_value = 0.
pixel_loss_value = 0.
for output, input in zip(output_valid, input_valid):
output, input = torch.unsqueeze(output, 0), torch.unsqueeze(input, 0)
pixel_loss_temp = self.mse_loss(input, output)
ssim_loss_temp = self.ssim_loss(input, output, normalize=True, val_range=255)
ssim_loss_value += (1 - ssim_loss_temp)
pixel_loss_value += pixel_loss_temp
ssim_loss_value /= len(output_valid)
pixel_loss_value /= len(output_valid)

total_loss = pixel_loss_value + 100*ssim_loss_value # 加权?
total_loss.backward()

def _backward_step_unrolled(self, input_train, input_valid, eta, network_optimizer):
unrolled_model = self._compute_unrolled_model(input_train, eta, network_optimizer)
unrolled_loss = unrolled_model._loss(input_valid)
unrolled_loss.backward()
dalpha = [v.grad for v in unrolled_model.arch_parameters()]
vector = [v.grad.data for v in unrolled_model.parameters()]
implicit_grads = self._hessian_vector_product(vector, input_train)
for g, ig in zip(dalpha, implicit_grads):
g.data.sub_(ig.data, alpha=eta[0])
for v, g in zip(self.model.arch_parameters(), dalpha):
if v.grad is None:
v.grad = Variable(g.data)
else:
v.grad.data.copy_(g.data)

def _construct_model_from_theta(self, theta):
model_new = self.model.new()
model_dict = self.model.state_dict()
params, offset = {}, 0
for k, v in self.model.named_parameters():
v_length = np.prod(v.size())
params[k] = theta[offset: offset + v_length].view(v.size())
offset += v_length
assert offset == len(theta)
model_dict.update(params)
model_new.load_state_dict(model_dict)
return model_new.cuda()

def _hessian_vector_product(self, vector, input, r=1e-2):
R = r / _concat(vector).norm()
for p, v in zip(self.model.parameters(), vector):
p.data.add_(v, alpha=R)
loss = self.model._loss(input)
grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
for p, v in zip(self.model.parameters(), vector):
p.data.sub_(v, alpha=2 * R)
loss = self.model._loss(input)
grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
for p, v in zip(self.model.parameters(), vector):
p.data.add_(v, alpha=R)
return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
19 changes: 19 additions & 0 deletions genotypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections import namedtuple

Genotype = namedtuple('Genotype', 'cell cell_concat')

PRIMITIVES = [
'none',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
'Spatialattention',
'Denseblocks',
'Residualblocks'
]

# epoch 20
genotype_en1 = Genotype(cell=[('sep_conv_3x3', 0), ('Spatialattention', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_5x5', 3), ('Denseblocks', 4), ('Denseblocks', 3)], cell_concat=range(2, 6))
genotype_en2 = Genotype(cell=[('Spatialattention', 1), ('dil_conv_3x3', 0), ('dil_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('Spatialattention', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], cell_concat=range(2, 6))
genotype_de = Genotype(cell=[('sep_conv_3x3', 0), ('Denseblocks', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('Denseblocks', 4), ('sep_conv_3x3', 0)], cell_concat=range(2, 6))
Loading

0 comments on commit 25c8191

Please sign in to comment.