diff --git a/xBD_code/adamw.py b/xBD_code/adamw.py new file mode 100644 index 0000000..4bb590c --- /dev/null +++ b/xBD_code/adamw.py @@ -0,0 +1,89 @@ +# Based on https://github.com/pytorch/pytorch/pull/3740 +import torch +import math + + +class AdamW(torch.optim.Optimizer): + """Implements AdamW algorithm. + + It has been proposed in `Fixing Weight Decay Regularization in Adam`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + .. Fixing Weight Decay Regularization in Adam: + https://arxiv.org/abs/1711.05101 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + super(AdamW, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # according to the paper, this penalty should come after the bias correction + # if group['weight_decay'] != 0: + # grad = grad.add(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + # w = w - wd * lr * w + if group['weight_decay'] != 0: + p.data.add_(-group['weight_decay'] * group['lr'], p.data) + + # w = w - lr * w.grad + p.data.addcdiv_(-step_size, exp_avg, denom) + + # w = w - wd * lr * w - lr * w.grad + # See http://www.fast.ai/2018/07/02/adam-weight-decay/ + + return loss \ No newline at end of file diff --git a/xBD_code/dual_hrnet_config.yaml b/xBD_code/dual_hrnet_config.yaml new file mode 100644 index 0000000..503778b --- /dev/null +++ b/xBD_code/dual_hrnet_config.yaml @@ -0,0 +1,118 @@ +OUTPUT_DIR: '' +LOG_DIR: '' +GPUS: [0,] +WORKERS: 4 +PRINT_FREQ: 20 +AUTO_RESUME: False +PIN_MEMORY: True +RANK: 0 + +# Cudnn related params +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True + +# common params for NETWORK +MODEL: + NAME: 'dual-hrnet' + PRETRAINED: './Checkpoints/HRNet/hrnetv2_w32_imagenet_pretrained.pth' + USE_FPN: False + IS_DISASTER_PRED: False + 
IS_SPLIT_LOSS: True + FUSE_CONV_K_SIZE: 1 + + # high_resoluton_net related params for segmentation + EXTRA: + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + FINAL_CONV_KERNEL: 1 + WITH_HEAD: True + + STAGE1: + NUM_MODULES: 1 + NUM_BRANCHES: 1 + NUM_BLOCKS: [4] + NUM_CHANNELS: [64] + BLOCK: 'BOTTLENECK' + FUSE_METHOD: 'SUM' + + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + NUM_BLOCKS: [4, 4] + NUM_CHANNELS: [32, 64] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + NUM_BLOCKS: [4, 4, 4] + NUM_CHANNELS: [32, 64, 128] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + NUM_BLOCKS: [4, 4, 4, 4] + NUM_CHANNELS: [32, 64, 128, 256] + BLOCK: 'BASIC' + FUSE_METHOD: 'SUM' + +#_C.MODEL.EXTRA= CN(new_allowed=True) + +LOSS: + CLASS_BALANCE: True + +# DATASET related params +DATASET: + NUM_CLASSES: 4 + +# training +TRAIN: + # Augmentation + FLIP: True + MULTI_SCALE: [0.8, 1.2] + CROP_SIZE: [512, 512] + + LR_FACTOR: 0.1 + LR_STEP: [90, 110] + LR: 0.05 + EXTRA_LR: 0.001 + + OPTIMIZER: 'sgd' + MOMENTUM: 0.9 + WD: 0.0001 + NESTEROV: False + IGNORE_LABEL: -1 + + NUM_EPOCHS: 500 + RESUME: False + + BATCH_SIZE_PER_GPU: 16 + SHUFFLE: True +# only using some training samples + NUM_SAMPLES: 0 + CLASS_WEIGHTS: [0.4, 1.2, 1.2, 1.2] + +# testing +TEST: + BATCH_SIZE_PER_GPU: 32 +# only testing some samples + NUM_SAMPLES: 0 + + MODEL_FILE: '' + FLIP_TEST: False + MULTI_SCALE: False + CENTER_CROP_TEST: False + SCALE_LIST: [1] + +# debug +DEBUG: + DEBUG: False + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: False + SAVE_HEATMAPS_PRED: False diff --git a/xBD_code/losses.py b/xBD_code/losses.py new file mode 100644 index 0000000..e8867cf --- /dev/null +++ b/xBD_code/losses.py @@ -0,0 +1,289 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Variable + +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse + +eps = 1e-6 + +def dice_round(preds, trues): + preds = preds.float() + return soft_dice_loss(preds, trues) + + +def iou_round(preds, trues): + preds = preds.float() + return jaccard(preds, trues) + + +def soft_dice_loss(outputs, targets, per_image=False): + batch_size = outputs.size()[0] + if not per_image: + batch_size = 1 + dice_target = targets.contiguous().view(batch_size, -1).float() + dice_output = outputs.contiguous().view(batch_size, -1) + intersection = torch.sum(dice_output * dice_target, dim=1) + union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps + loss = (1 - (2 * intersection + eps) / union).mean() + return loss + + +def jaccard(outputs, targets, per_image=False): + batch_size = outputs.size()[0] + if not per_image: + batch_size = 1 + dice_target = targets.contiguous().view(batch_size, -1).float() + dice_output = outputs.contiguous().view(batch_size, -1) + intersection = torch.sum(dice_output * dice_target, dim=1) + union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) - intersection + eps + losses = 1 - (intersection + eps) / union + return losses.mean() + + +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True, per_image=False): + super().__init__() + self.size_average = size_average + self.register_buffer('weight', weight) + self.per_image = per_image + + def forward(self, input, target): + return soft_dice_loss(input, target, per_image=self.per_image) + + +class JaccardLoss(nn.Module): + def __init__(self, 
weight=None, size_average=True, per_image=False): + super().__init__() + self.size_average = size_average + self.register_buffer('weight', weight) + self.per_image = per_image + + def forward(self, input, target): + return jaccard(input, target, per_image=self.per_image) + + +class StableBCELoss(nn.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + + def forward(self, input, target): + input = input.float().view(-1) + target = target.float().view(-1) + neg_abs = - input.abs() + # todo check correctness + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + +class MaskLoss(nn.Module): + # Per-pixel sigmoid for each class C separately + # Average binary cross entropy loss + def __init__(self): + super(MaskLoss, self).__init__() + + def forward(self, input, target): + input = input.float().view(-1) + target = target.float().view(-1) + loss = F.binary_cross_entropy(input, target) + return loss + + +class ComboLoss(nn.Module): + def __init__(self, weights, per_image=False): + super().__init__() + self.weights = weights + self.bce = StableBCELoss() + self.dice = DiceLoss(per_image=False) + self.jaccard = JaccardLoss(per_image=False) + self.lovasz = LovaszLoss(per_image=per_image) + self.lovasz_sigmoid = LovaszLossSigmoid(per_image=per_image) + self.focal = FocalLoss2d() + self.mask_bceavg = MaskLoss() + self.mapping = {'bce': self.bce, + 'dice': self.dice, + 'focal': self.focal, + 'jaccard': self.jaccard, + 'lovasz': self.lovasz, + 'lovasz_sigmoid': self.lovasz_sigmoid, + 'mask_bceavg': self.mask_bceavg} + self.expect_sigmoid = {'dice', 'focal', 'jaccard', 'lovasz_sigmoid', 'mask_bceavg'} + self.values = {} + + def forward(self, outputs, targets): + loss = 0 + weights = self.weights + sigmoid_input = torch.sigmoid(outputs) + for k, v in weights.items(): + if not v: + continue + val = self.mapping[k](sigmoid_input if k in self.expect_sigmoid else outputs, targets) + self.values[k] = val + loss += self.weights[k] * val + return loss + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts.float() - gt_sorted.float().cumsum(0) + union = gts.float() + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. 
- logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +def lovasz_sigmoid(probas, labels, per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + only_present: average only on classes present in ground truth + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_sigmoid_flat(*flatten_binary_scores(prob.unsqueeze(0), lab.unsqueeze(0), ignore)) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_sigmoid_flat(*flatten_binary_scores(probas, labels, ignore)) + return loss + + +def lovasz_sigmoid_flat(probas, labels): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + only_present: average only on classes present in ground truth + """ + fg = labels.float() + errors = (Variable(fg) - probas).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))) + return loss + + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(np.isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n + + +class LovaszLoss(nn.Module): + def __init__(self, ignore_index=255, per_image=True): + super().__init__() + self.ignore_index = ignore_index + self.per_image = per_image + + def forward(self, outputs, targets): + outputs = outputs.contiguous() + targets = targets.contiguous() + return lovasz_hinge(outputs, targets, per_image=self.per_image, ignore=self.ignore_index) + + +class LovaszLossSigmoid(nn.Module): + def __init__(self, ignore_index=255, per_image=True): + super().__init__() + self.ignore_index = ignore_index + self.per_image = per_image + + def forward(self, outputs, targets): + outputs = outputs.contiguous() + targets = targets.contiguous() + return lovasz_sigmoid(outputs, targets, per_image=self.per_image, ignore=self.ignore_index) + + +class FocalLoss2d(nn.Module): + def __init__(self, gamma=2, ignore_index=255): + super().__init__() + self.gamma = gamma + self.ignore_index = ignore_index + + def forward(self, outputs, targets): + outputs = outputs.contiguous() + targets = targets.contiguous() + # eps = 1e-8 + non_ignored = targets.view(-1) != self.ignore_index + targets = targets.view(-1)[non_ignored].float() + outputs = outputs.contiguous().view(-1)[non_ignored] + outputs = torch.clamp(outputs, eps, 1. - eps) + targets = torch.clamp(targets, eps, 1. 
- eps) + pt = (1 - targets) * (1 - outputs) + targets * outputs + return (-(1. - pt) ** self.gamma * torch.log(pt)).mean() \ No newline at end of file diff --git a/xBD_code/predict_test_cls.py b/xBD_code/predict_test_cls.py new file mode 100644 index 0000000..9d352d4 --- /dev/null +++ b/xBD_code/predict_test_cls.py @@ -0,0 +1,101 @@ +import os + +from os import path, makedirs, listdir +import sys +import numpy as np +np.random.seed(1) +import random +random.seed(1) + +import torch +from torch import nn +from torch.backends import cudnn + +from torch.autograd import Variable + +import pandas as pd +from tqdm import tqdm +import timeit +import cv2 + +from zoo.models import UNet_Change_Transformer_BiT + +from utils import * + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +test_dir = '/scratch/nka77/DATA/test/images' +models_folder = '/scratch/nka77/xview_first/weights' +pred_folder = '/scratch/nka77/xview_first/pred/' + +if __name__ == '__main__': + t0 = timeit.default_timer() + + pred_folder = pred_folder + 'unettransformerBottleneck_cls_' + makedirs(pred_folder, exist_ok=True) + + models = [] + for snapshot in ['unettransformerBottleneck_cls_0_0_iter10']: + snap_to_load = snapshot + model = UNet_Change_Transformer().cuda() + model = nn.DataParallel(model).cuda() + print("=> loading checkpoint '{}'".format(snap_to_load)) + checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu') + loaded_dict = checkpoint['state_dict'] + sd = model.state_dict() + for k in model.state_dict(): + if k in loaded_dict and sd[k].size() == loaded_dict[k].size(): + sd[k] = loaded_dict[k] + loaded_dict = sd + model.load_state_dict(loaded_dict) + print("loaded checkpoint '{}' (epoch {}, best_score {})" + .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score'])) + model.eval() + models.append(model) + + + with torch.no_grad(): + for f in tqdm(sorted(listdir(test_dir))): + if '_pre_' in f: + fn = path.join(test_dir, f) + img = cv2.imread(fn, cv2.IMREAD_COLOR) + img2 = cv2.imread(fn.replace('_pre_', '_post_'), cv2.IMREAD_COLOR) + if (img.shape != img2.shape): + continue + img = np.concatenate([img, img2], axis=2) + img = preprocess_inputs(img) + + inp = [] + inp.append(img) + inp.append(img[::-1, ...]) + inp.append(img[:, ::-1, ...]) + inp.append(img[::-1, ::-1, ...]) + inp = np.asarray(inp, dtype='float') + inp = torch.from_numpy(inp.transpose((0, 3, 1, 2))).float() + inp = Variable(inp).cuda() + + # img = torch.from_numpy(np.asarray(img.transpose(2, 0, 1), dtype='float')).unsqueeze(0).float().cuda() + pred = [] + # for i in range(4): + for model in models: + msk = model(inp) + msk = torch.sigmoid(msk) + msk = msk.cpu().numpy() + # pred.append(msk) + pred.append(msk[0, ...]) + pred.append(msk[1, :, ::-1, :]) + pred.append(msk[2, :, :, ::-1]) + pred.append(msk[3, :, ::-1, ::-1]) + + pred_full = np.asarray(pred).mean(axis=0) + + msk = pred_full * 255 + msk = msk.astype('uint8').transpose(1, 2, 0) + np.save(path.join(pred_folder, '{0}.png'.format(f.replace('.png', '_full.png'))), msk) + cv2.imwrite(path.join(pred_folder, '{0}.png'.format(f.replace('.png', '_part1.png'))), msk[..., :3], [cv2.IMWRITE_PNG_COMPRESSION, 9]) + cv2.imwrite(path.join(pred_folder, '{0}.png'.format(f.replace('.png', '_part2.png'))), msk[..., 2:], [cv2.IMWRITE_PNG_COMPRESSION, 9]) + + + elapsed = timeit.default_timer() - t0 + print('Time: {:.3f} min'.format(elapsed / 60)) diff --git a/xBD_code/train_GAN.py b/xBD_code/train_GAN.py new file mode 100644 index 0000000..98e6541 --- /dev/null +++ 
b/xBD_code/train_GAN.py @@ -0,0 +1,490 @@ +import os +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" + +from os import path, makedirs, listdir +import sys +import numpy as np +np.random.seed(1) +import random +random.seed(1) + +import torch +from torch import nn +from torch.backends import cudnn +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import torch.optim.lr_scheduler as lr_scheduler +#from apex import amp +import torch.cuda.amp as amp +from torchvision import transforms +import torchvision.transforms.functional as TF +import torch.nn.functional as F +from adamw import AdamW +from losses import dice_round, ComboLoss + +from tqdm import tqdm +import timeit +#import cv2 + +from zoo.models import BASE_Transformer, Res34_Unet_Double +from zoo.model_transformer_encoding import BASE_Transformer_UNet, Discriminator + + +from torch.autograd import Variable +from PIL import Image + +model = "TUNet" + +if model == "TUNet": + print("UNet Transformer") + model = BASE_Transformer_UNet(input_nc=3, output_nc=5, token_len=4, resnet_stages_num=4, + with_pos='learned', with_decoder_pos='learned', enc_depth=1, dec_depth=8).cuda() + snapshot_name = 'BASE_UNet_Transformer_img512_lossv7_AOI_GAN' + print("snapshot_name ", snapshot_name, "with seg and cls headers and ce loss only on building") + print("Loss only building patch lr:0.001 Seg weights: loss_seg = loss0") + print("CE weights_ = torch.tensor([0.001,0.10,1.5,1.5,1.5])") + print("upsampling 1:3 with 50%") + print("GAN Loss") + snap_to_load = 'res34_loc_0_1_best' + +elif model == "BiT": + print("BiT ....") + model = BASE_Transformer(input_nc=3, output_nc=5, token_len=4, resnet_stages_num=4, + with_pos='learned', enc_depth=1, dec_depth=8).cuda() + snapshot_name = 'BiT_lossv2' + print("snapshot_name ", snapshot_name) + print("Loss only building patch lr:0.001 Seg weights: loss_seg = loss0 ") + print("CE weights_ = torch.tensor([0.001,0.10,1.5,1.0,1.5])") + print("reduced upsampling of images 1 and 3") + snap_to_load = 'res34_loc_0_1_best' + +else: + model = Res34_Unet_Double().cuda() + snapshot_name = 'Res34_Unet_Double_img512_lossv5_AOI' + snap_to_load = 'res34_loc_0_1_best' + + +#from imgaug import augmenters as nniaa +from utils import * +#from scikitimage.morphology import square, dilation +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score + +import gc +torch.cuda.empty_cache() + +#cv2.setNumThreads(0) +#cv2.ocl.setUseOpenCL(False) + +train_dirs = ['../DATA/train', '../DATA/tier3', '../DATA/AOI3'] +models_folder = 'weights' + +input_shape = (1024,1024) +crop_size = 512 +_thr = 0.3 +cudnn.benchmark = True +batch_size = 8 +val_batch_size = 8 + +all_files = [] +for d in train_dirs: + for f in sorted(listdir(path.join(d, 'images'))): + if ('_pre_disaster.png' in f):# and (('hurricane-harvey' in f)):# | ('hurricane-michael' in f) | ('mexico-earthquake' in f) | ('tuscaloosa-tornado' in f)): + all_files.append(path.join(d, 'images', f)) + + +valid = Variable(torch.ones((batch_size, 1000)), requires_grad=False).cuda() +fake = Variable(torch.zeros((batch_size, 1000)), requires_grad=False).cuda() + + +class TrainData(Dataset): + def __init__(self, train_idxs): + super().__init__() + self.train_idxs = train_idxs + #self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2) + + def __len__(self): + return len(self.train_idxs) + + def __getitem__(self, idx): + _idx = self.train_idxs[idx] + + fn = 
all_files[_idx] + + img = np.array(Image.open(fn)) + img2 = np.array(Image.open(fn.replace('_pre_disaster', '_post_disaster'))) + + msk0 = np.array(Image.open(fn.replace('/images/', '/masks/'))) + lbl_msk1 = np.array(Image.open(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'))) + + x0 = random.randint(0, img.shape[1] - crop_size) + y0 = random.randint(0, img.shape[0] - crop_size) + + img1 = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + + if random.random() > 0.7: + imgs = [img1, img2] + labels = [msk0, lbl_msk1] + imgs = [TF.to_pil_image(img) for img in imgs] + labels = [TF.to_pil_image(img) for img in labels] + + if random.random() > 0.3: + imgs = [TF.hflip(img) for img in imgs] + labels = [TF.hflip(img) for img in labels] + + if random.random() > 0.3: + imgs = [TF.vflip(img) for img in imgs] + labels = [TF.vflip(img) for img in labels] + + if random.random() > 0.3: + x = random.randint(0, 200) + y = random.randint(0, 200) + imgs = [TF.resized_crop(img, x, y, crop_size-x, crop_size-y, (crop_size,crop_size)) for img in imgs] + labels = [TF.resized_crop(img, x, y, crop_size-x, crop_size-y, (crop_size,crop_size)) for img in labels] + + if random.random() > 0.7: + imgs = [transforms.ColorJitter(brightness=[0.8,1.2], contrast=[0.8,1.2], saturation=[0.8,1.2])(img) for img in imgs] + + msk0, lbl_msk1 = np.array(labels[0]), np.array(labels[1]) + img1, img2 = np.array(imgs[0]), np.array(imgs[1]) + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + msk1[lbl_msk1 == 1] = 255 + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk[..., 0] = False + '''msk[..., 1] = dilation(msk[..., 1], square(5)) + msk[..., 2] = dilation(msk[..., 2], square(5)) + msk[..., 3] = dilation(msk[..., 3], square(5)) + msk[..., 4] = dilation(msk[..., 4], square(5))''' + msk[..., 1][msk[..., 2:].max(axis=2)] = False + msk[..., 3][msk[..., 2]] = False + msk[..., 4][msk[..., 2]] = False + msk[..., 4][msk[..., 3]] = False + msk[..., 0][msk[..., 1:].max(axis=2)] = True + msk = msk * 1 + + lbl_msk = msk.argmax(axis=2) + + img = np.concatenate([img1, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.tensor(img.transpose((2, 0, 1))).float() + msk = torch.tensor(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn} + return sample + + +class ValData(Dataset): + def __init__(self, image_idxs): + super().__init__() + self.image_idxs = image_idxs + + def __len__(self): + return len(self.image_idxs) + + def __getitem__(self, idx): + _idx = self.image_idxs[idx] + + fn = all_files[_idx] + + img = np.array(Image.open(fn)) + img2 = np.array(Image.open(fn.replace('_pre_disaster', '_post_disaster'))) + + # msk_loc = cv2.imread(path.join(loc_folder, '{0}.png'.format(fn.split('/')[-1].replace('.png', '_part1.png'))), cv2.IMREAD_UNCHANGED) > (0.3*255) + + msk0 = np.array(Image.open(fn.replace('/images/', '/masks/'))) + lbl_msk1 = np.array(Image.open(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'))) + + x0 = 512 + y0 = 512 + + 
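+        # Deterministic validation crop: the bottom-right 512x512 quadrant of the 1024x1024
+        # tile, so validation scores are computed on identical pixels every epoch.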
img = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk1[lbl_msk1 == 1] = 255 + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk = msk * 1 + + lbl_msk = msk[..., 1:].argmax(axis=2) + + img = np.concatenate([img, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.tensor(img.transpose((2, 0, 1))).float() + msk = torch.tensor(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn, 'msk_loc': msk} + return sample + + +def validate(model, data_loader): + dices0 = [] + + tp = np.zeros((4,)) + fp = np.zeros((4,)) + fn = np.zeros((4,)) + totalp = np.zeros((4,)) + + + data_loader = tqdm(data_loader) + with torch.no_grad(): + for i, sample in enumerate(data_loader): + msks = sample["msk"].numpy() + lbl_msk = sample["lbl_msk"].numpy() + imgs = sample["img"].cuda(non_blocking=True) + # msk_loc = sample["msk_loc"].numpy() * 1 + out = model(imgs) + + # msk_pred = msk_loc + msk_pred = torch.sigmoid(out).cpu().numpy()[:, 0, ...] + msk_damage_pred = torch.sigmoid(out).cpu().numpy()[:, 1:, ...] + + for j in range(msks.shape[0]): + dices0.append(dice(msks[j, 0], msk_pred[j] > _thr)) + targ = lbl_msk[j][lbl_msk[j, 0] > 0] + pred = msk_damage_pred[j].argmax(axis=0) + pred = pred * (msk_pred[j] > _thr) + pred = pred[lbl_msk[j, 0] > 0] + for c in range(4): + tp[c] += np.logical_and(pred == c, targ == c).sum() + fn[c] += np.logical_and(pred != c, targ == c).sum() + fp[c] += np.logical_and(pred == c, targ != c).sum() + totalp += (targ == c).sum() + + d0 = np.mean(dices0) + + f1_sc = np.zeros((4,)) + for c in range(4): + f1_sc[c] = 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) + f1 = 4 / np.sum(1.0 / (f1_sc + 1e-6)) + + sc = 0.3 * d0 + 0.7 * f1 + print("Val Score: {}, Dice: {}, F1: {}, F1_0: {}, F1_1: {}, F1_2: {}, F1_3: {}".format(sc, d0, f1, f1_sc[0], f1_sc[1], f1_sc[2], f1_sc[3])) + return sc + + +def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch): + model = model.eval() + d = validate(model, data_loader=data_val) + + if d > best_score: + torch.save({ + 'epoch': current_epoch + 1, + 'state_dict': model.state_dict(), + 'best_score': d, + 'optimizer' : optimizer.state_dict(), + }, path.join(models_folder, snapshot_name)) + best_score = d + + print("score: {}\tscore_best: {}".format(d, best_score)) + return best_score + + +def train_epoch(current_epoch, model, discriminator, optimizer, scheduler, train_data_loader): + losses = AverageMeter() + losses1 = AverageMeter() + losses2 = AverageMeter() + lossesgan = AverageMeter() + + seg_loss = ComboLoss({'dice': 1, 'focal': 8}, per_image=False).cuda() + weights_ = torch.tensor([0.001, 0.10,1.5,1.5,1.5]) + ce_loss = nn.CrossEntropyLoss(weight=weights_).cuda() + gan_loss = nn.BCEWithLogitsLoss().cuda() + + # weights_ = torch.tensor([0.10,1.,1., 1.]) + # ce_loss_avg = nn.CrossEntropyLoss(weight=weights_).cuda() + # ce_loss = nn.CrossEntropyLoss().cuda() + + + iterator = tqdm(train_data_loader) + # iterator = train_data_loader + model.train() 
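+    # Training scheme: the segmentation model (generator) is updated every batch with a
+    # dice+focal loss on the localization channel, a class-weighted CE over the 5 output
+    # channels (x5) and a small (0.01x) adversarial term; the discriminator is updated only
+    # on every 8th batch, on real masks vs. detached generator outputs.
+    # Note: the generator's adversarial term is computed on out.detach(), so as written it
+    # carries no gradient back into the generator.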
+ for i, sample in enumerate(iterator): + imgs = sample["img"].cuda(non_blocking=True) + msks = sample["msk"].cuda(non_blocking=True) + lbl_msk = sample["lbl_msk"].cuda(non_blocking=True) + + #### GENERATOR ### + model.zero_grad() + out = model(imgs) + + if (i % 8 == 0): + #### DISCRIMININATOR ### + discriminator.zero_grad() + + msks = msks.to(torch.float32) + true_label = discriminator(msks) + loss_gan_1 = gan_loss(true_label, valid) + fake_label = discriminator(out.detach()) + loss_gan_0 = gan_loss(fake_label, fake) + loss_D = 0.1*(loss_gan_1 + loss_gan_0)/2 + loss_D.backward() + d_optimizer.step() + + #### GENERATOR ### + loss_seg = seg_loss(out[:, 0, ...], msks[:, 0, ...]) + + msks[:, 0, ...] = 1 - msks[:, 0, ...] + lbl_msk = torch.argmax(msks, dim=1) + loss_cls = ce_loss(out, lbl_msk) * 5 + + # out_ordinal = torch.argmax(out, dim=1) + # lbl_msk = lbl_msk.float() + # loss_ordinal = nn.MSELoss()(out_ordinal, lbl_msk)/5 + + gen_label = discriminator(out.detach()) + loss_gan = gan_loss(gen_label, valid) + loss_G = loss_seg + loss_cls + 0.01*loss_gan + + loss_G.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.999) + optimizer.step() + + losses.update(loss_G.item(), imgs.size(0)) + losses1.update(loss_D.item(), imgs.size(0)) #loss5 + losses2.update(loss_cls.item(), imgs.size(0)) + lossesgan.update(loss_gan, imgs.size(0)) + + iterator.set_description( + "epoch: {}; lr {:.7f}; Loss {loss.val:.4f}; loss_D {loss1.val:.4f}; loss_cls {loss2.avg:.4f}; loss_gan {dice.val:.4f}".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, dice=lossesgan)) + + + scheduler.step(current_epoch) + print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; loss2 {loss1.avg:.4f}; Dice {dice.avg:.4f}".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=lossesgan)) + + +if __name__ == '__main__': + t0 = timeit.default_timer() + + makedirs(models_folder, exist_ok=True) + seed = 0 + + file_classes = [] + AOI_files = [] + for fn in tqdm(all_files): + fl = np.zeros((4,), dtype=bool) + msk1 = np.array(Image.open(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'))) + for c in range(1, 5): + fl[c-1] = c in msk1 + file_classes.append(fl) + if 'AOI' in fn: + file_classes.append(fl) + file_classes = np.asarray(file_classes) + + train_idxs0, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed) + + np.random.seed(seed + 321) + random.seed(seed + 321) + + train_idxs = [] + non_zero_bldg = 0 + non_zero_dmg = 0 + for i in train_idxs0: + if file_classes[i, :].max(): + train_idxs.append(i) + non_zero_bldg += 1 + if (random.random() > 0.5) and file_classes[i, 1:].max(): + train_idxs.append(i) + non_zero_dmg += 1 + # if (random.random() > 0.7) and file_classes[i, 3].max(): + # train_idxs.append(i) + # non_zero_dmg += 1 + + train_idxs = np.asarray(train_idxs) + steps_per_epoch = len(train_idxs) // batch_size + validation_steps = len(val_idxs) // val_batch_size + print(non_zero_bldg, non_zero_dmg, len(train_idxs), len(val_idxs)) + print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps) + + data_train = TrainData(train_idxs) + val_train = ValData(val_idxs) + + discriminator = Discriminator().cuda() + + train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=8, shuffle=True, pin_memory=False, drop_last=True) + val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=8, shuffle=False, pin_memory=False) + + optimizer = 
AdamW(model.parameters(), lr=0.001, weight_decay=1e-6) + d_optimizer = AdamW(discriminator.parameters(), lr=0.0001, weight_decay=1e-6) + + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.6) + + # snap_to_load = 'res34_loc_{}_1_best'.format(seed) + ''' + print("=> loading checkpoint '{}'".format(snap_to_load)) + checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu') + loaded_dict = checkpoint['state_dict'] + sd = model.state_dict() + for k in model.state_dict(): + if k in loaded_dict and sd[k].size() == loaded_dict[k].size(): + sd[k] = loaded_dict[k] + loaded_dict = sd + model.load_state_dict(loaded_dict) + print("loaded checkpoint '{}' (epoch {}, best_score {})" + .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score'])) + del loaded_dict + del sd + del checkpoint +''' + gc.collect() + torch.cuda.empty_cache() + + model = nn.DataParallel(model).cuda() + # discriminator = nn.DataParallel(discriminator) + + best_score = 0 + torch.cuda.empty_cache() + + scaler = amp.GradScaler() + + for epoch in range(100): + train_epoch(epoch, model, discriminator, optimizer, scheduler, train_data_loader) + if epoch % 2 == 0: + torch.cuda.empty_cache() + best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch) + + elapsed = timeit.default_timer() - t0 + torch.cuda.empty_cache() + print('Time: {:.3f} min'.format(elapsed / 60)) + diff --git a/xBD_code/train_dual_hrnet.py b/xBD_code/train_dual_hrnet.py new file mode 100644 index 0000000..831f995 --- /dev/null +++ b/xBD_code/train_dual_hrnet.py @@ -0,0 +1,467 @@ +import os +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" + +from os import path, makedirs, listdir +import sys +import numpy as np +np.random.seed(1) +import random +random.seed(1) + +import torch +from torch import nn +from torch.backends import cudnn +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import torch.optim.lr_scheduler as lr_scheduler +#from apex import amp +import torch.cuda.amp as amp +from torchvision import transforms +import torchvision.transforms.functional as TF +from adamw import AdamW +from losses import dice_round, ComboLoss + +from tqdm import tqdm +import timeit +import cv2 + +import torch.nn.functional as F + +from zoo.models import BASE_Transformer, Res34_Unet_Double +from zoo.model_transformer_encoding import BASE_Transformer_UNet +from dual_hrnet import DualHRNet, get_model +from yacs.config import CfgNode + +from PIL import Image + +model = "dual_hrnet" + +if model == "dual_hrnet": + print("Dual HRNET") + with open("dual_hrnet_config.yaml", 'rb') as fp: + config = CfgNode.load_cfg(fp) + model = get_model(config) + snapshot_name = 'dual_hrnet' + print("snapshot_name ", snapshot_name, "with seg and cls headers and ce loss only on building") + # snap_to_load = 'res34_loc_0_1_best' + +else: + model = Res34_Unet_Double().cuda() + snapshot_name = 'Res34_Unet_Double' + snap_to_load = 'res34_loc_0_1_best' + + +from imgaug import augmenters as iaa +from utils import * +from skimage.morphology import square, dilation +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score + +import gc +torch.cuda.empty_cache() + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +train_dirs = ['/scratch/nka77/DATA/train', '/scratch/nka77/DATA/tier3'] +models_folder = 
'/scratch/nka77/xview_first/weights' + +loc_folder = '/scratch/nka77/xview_first/pred/pred34_loc_val' + +input_shape = (1024,1024) +crop_size = 512 + +all_files = [] +for d in train_dirs: + for f in sorted(listdir(path.join(d, 'images'))): + if ('_pre_disaster.png' in f): + all_files.append(path.join(d, 'images', f)) + + +class TrainData(Dataset): + def __init__(self, train_idxs): + super().__init__() + self.train_idxs = train_idxs + self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2) + + def __len__(self): + return len(self.train_idxs) + + def __getitem__(self, idx): + _idx = self.train_idxs[idx] + + fn = all_files[_idx] + + img = cv2.imread(fn, cv2.IMREAD_COLOR) + img2 = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR) + + msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED) + lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + + x0 = random.randint(0, img.shape[1] - crop_size) + y0 = random.randint(0, img.shape[0] - crop_size) + + img1 = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + + if random.random() > 0.7: + imgs = [img1, img2] + labels = [msk0, lbl_msk1] + imgs = [TF.to_pil_image(img) for img in imgs] + labels = [TF.to_pil_image(img) for img in labels] + + if random.random() > 0.3: + imgs = [TF.hflip(img) for img in imgs] + labels = [TF.hflip(img) for img in labels] + + if random.random() > 0.3: + imgs = [TF.vflip(img) for img in imgs] + labels = [TF.vflip(img) for img in labels] + + if random.random() > 0.5: + imgs = [transforms.ColorJitter(brightness=[0.8,1.2], contrast=[0.8,1.2], saturation=[0.8,1.2])(img) for img in imgs] + + msk0, lbl_msk1 = np.array(labels[0]), np.array(labels[1]) + img1, img2 = np.array(imgs[0]), np.array(imgs[1]) + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + msk1[lbl_msk1 == 1] = 255 + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk[..., 0] = False + msk[..., 1] = dilation(msk[..., 1], square(5)) + msk[..., 2] = dilation(msk[..., 2], square(5)) + msk[..., 3] = dilation(msk[..., 3], square(5)) + msk[..., 4] = dilation(msk[..., 4], square(5)) + msk[..., 1][msk[..., 2:].max(axis=2)] = False + msk[..., 3][msk[..., 2]] = False + msk[..., 4][msk[..., 2]] = False + msk[..., 4][msk[..., 3]] = False + msk[..., 0][msk[..., 1:].max(axis=2)] = True + msk = msk * 1 + + lbl_msk = msk.argmax(axis=2) + + img = np.concatenate([img1, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.tensor(img.transpose((2, 0, 1))).float() + msk = torch.tensor(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn} + return sample + + +class ValData(Dataset): + def __init__(self, image_idxs): + super().__init__() + self.image_idxs = image_idxs + + def __len__(self): + return len(self.image_idxs) + + def __getitem__(self, idx): + _idx = self.image_idxs[idx] + + fn = all_files[_idx] + + img = cv2.imread(fn, cv2.IMREAD_COLOR) + img2 = cv2.imread(fn.replace('_pre_disaster', 
'_post_disaster'), cv2.IMREAD_COLOR) + + # msk_loc = cv2.imread(path.join(loc_folder, '{0}.png'.format(fn.split('/')[-1].replace('.png', '_part1.png'))), cv2.IMREAD_UNCHANGED) > (0.3*255) + + msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED) + lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + + x0 = 512 + y0 = 512 + + img = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk1[lbl_msk1 == 1] = 255 + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk = msk * 1 + + lbl_msk = msk[..., 1:].argmax(axis=2) + + img = np.concatenate([img, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.from_numpy(img.transpose((2, 0, 1))).float() + msk = torch.from_numpy(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn, 'msk_loc': msk} + return sample + + +def validate(model, data_loader): + dices0 = [] + + tp = np.zeros((4,)) + fp = np.zeros((4,)) + fn = np.zeros((4,)) + totalp = np.zeros((4,)) + + _thr = 0.3 + data_loader = tqdm(data_loader) + with torch.no_grad(): + for i, sample in enumerate(data_loader): + msks = sample["msk"].numpy() + lbl_msk = sample["lbl_msk"].numpy() + imgs = sample["img"].cuda(non_blocking=True) + # msk_loc = sample["msk_loc"].numpy() * 1 + out = model(imgs) + + out_loc = out['loc'] + out_cls = out['cls'] + out_loc = F.interpolate(out_loc, size=(lbl_msk.shape[-1],lbl_msk.shape[-1])) + out_cls = F.interpolate(out_cls, size=(lbl_msk.shape[-1],lbl_msk.shape[-1])) + + out_loc = out_loc.argmax(axis=1) + + msk_pred = torch.sigmoid(out_loc).cpu().numpy() + msk_damage_pred = torch.sigmoid(out_cls).cpu().numpy() + + for j in range(msks.shape[0]): + dices0.append(dice(msks[j, 0], msk_pred[j] > _thr)) + targ = lbl_msk[j][lbl_msk[j, 0] > 0] + pred = msk_damage_pred[j].argmax(axis=0) + pred = pred * (msk_pred[j] > _thr) + pred = pred[lbl_msk[j, 0] > 0] + for c in range(4): + tp[c] += np.logical_and(pred == c, targ == c).sum() + fn[c] += np.logical_and(pred != c, targ == c).sum() + fp[c] += np.logical_and(pred == c, targ != c).sum() + totalp += (targ == c).sum() + + d0 = np.mean(dices0) + + f1_sc = np.zeros((4,)) + for c in range(4): + f1_sc[c] = 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) + f1 = 4 / np.sum(1.0 / (f1_sc + 1e-6)) + + f1_sc_wt = np.zeros((4,)) + totalp = totalp/sum(totalp) + for c in range(4): + f1_sc_wt[c] = totalp[c] * 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) + f1_wt = 1 / np.sum(1.0 / (f1_sc_wt + 1e-6)) + + sc = 0.3 * d0 + 0.7 * f1 + print("Val Score: {}, Dice: {}, F1: {}, F1wt: {}, F1_0: {}, F1_1: {}, F1_2: {}, F1_3: {}".format(sc, d0, f1, f1_wt, f1_sc[0], f1_sc[1], f1_sc[2], f1_sc[3])) + return sc + + +def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch): + model = model.eval() + d = validate(model, data_loader=data_val) + + if d > best_score: + torch.save({ + 'epoch': current_epoch + 1, + 'state_dict': model.state_dict(), + 'best_score': d, + 
'optimizer' : optimizer.state_dict(), + }, path.join(models_folder, snapshot_name)) + best_score = d + + print("score: {}\tscore_best: {}".format(d, best_score)) + return best_score + + +def train_epoch(current_epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader): + losses = AverageMeter() + losses1 = AverageMeter() + losses2 = AverageMeter() + dices = AverageMeter() + + iterator = tqdm(train_data_loader) + # iterator = train_data_loader + + seg_loss = ComboLoss({'dice': 1, 'focal': 8}, per_image=False).cuda() + weights_ = torch.tensor([0.10,2.,1, 2]) + ce_loss = nn.CrossEntropyLoss(weight=weights_).cuda() + + model.train() + for i, sample in enumerate(iterator): + imgs = sample["img"].cuda(non_blocking=True) + msks = sample["msk"].cuda(non_blocking=True) + lbl_msk = sample["lbl_msk"].cuda(non_blocking=True) + + with amp.autocast(): + out = model(imgs) + + out_loc = out['loc'] + out_cls = out['cls'] + out_loc = F.interpolate(out_loc, size=(lbl_msk.shape[-1], lbl_msk.shape[-1])) + out_cls = F.interpolate(out_cls, size=(lbl_msk.shape[-1], lbl_msk.shape[-1])) + + out_loc = out_loc.argmax(axis=1) + loss0 = seg_loss(out_loc, msks[:, 0, ...]) + + loss_seg = loss0 + + true_bldg = torch.argmax(msks[:,1:, ...], dim=1) + # out_cls = out_cls[:,1:, ...] + # print(true_bldg.shape, out_cls.shape) + loss_cls = ce_loss(out_cls, true_bldg) * 5 + + loss = loss_seg + loss_cls + + with torch.no_grad(): + _probs = torch.sigmoid(out_cls[:, 0, ...]) + dice_sc = 1 - dice_round(_probs, msks[:, 0, ...]) + + losses.update(loss.item(), imgs.size(0)) + losses1.update(loss_seg.item(), imgs.size(0)) #loss5 + losses2.update(loss_cls.item(), imgs.size(0)) + + dices.update(dice_sc, imgs.size(0)) + + iterator.set_description( + "epoch: {}; lr {:.7f}; Loss {loss.val:.4f}; loss_seg {loss1.avg:.4f}; loss_cls {loss2.avg:.4f}; Dice {dice.val:.4f}".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, dice=dices)) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + scheduler.step(current_epoch) + + print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; loss2 {loss1.avg:.4f}; Dice {dice.avg:.4f}".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices)) + + + +if __name__ == '__main__': + t0 = timeit.default_timer() + + makedirs(models_folder, exist_ok=True) + seed = 0 + + cudnn.benchmark = True + batch_size = 4 + val_batch_size = 4 + + file_classes = [] + for fn in tqdm(all_files): + fl = np.zeros((4,), dtype=bool) + msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + for c in range(1, 5): + fl[c-1] = c in msk1 + file_classes.append(fl) + file_classes = np.asarray(file_classes) + + train_idxs0, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed) + + np.random.seed(seed + 321) + random.seed(seed + 321) + + train_idxs = [] + non_zero_bldg = 0 + non_zero_dmg = 0 + for i in train_idxs0: + if file_classes[i, :].max(): + train_idxs.append(i) + non_zero_bldg += 1 + if (random.random() > 0.7) and file_classes[i, 1].max(): + train_idxs.append(i) + non_zero_dmg += 1 + if (random.random() > 0.7) and file_classes[i, 3].max(): + train_idxs.append(i) + non_zero_dmg += 1 + + train_idxs = np.asarray(train_idxs) + steps_per_epoch = len(train_idxs) // batch_size + validation_steps = len(val_idxs) // val_batch_size + + print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps) + + 
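+    # The index list built above keeps only tiles whose post-disaster mask contains at least
+    # one building, and appends a random ~30% of tiles containing class-2 or class-4 damage a
+    # second time to partially rebalance the rarer damage classes.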
data_train = TrainData(train_idxs) + val_train = ValData(val_idxs) + + train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=6, shuffle=True, pin_memory=False, drop_last=True) + val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=6, shuffle=False, pin_memory=False) + + params = model.parameters() + optimizer = AdamW(params, lr=0.001, weight_decay=1e-6) + + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.6) + + # print("=> loading checkpoint '{}'".format(snap_to_load)) + # checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu') + # loaded_dict = checkpoint['state_dict'] + # sd = model.state_dict() + # for k in model.state_dict(): + # if k in loaded_dict and sd[k].size() == loaded_dict[k].size(): + # sd[k] = loaded_dict[k] + # loaded_dict = sd + # model.load_state_dict(loaded_dict) + # print("loaded checkpoint '{}' (epoch {}, best_score {})" + # .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score'])) + # del loaded_dict + # del sd + # del checkpoint + + gc.collect() + torch.cuda.empty_cache() + + model = nn.DataParallel(model).cuda() + + seg_loss = ComboLoss({'dice': 1, 'focal': 8}, per_image=False).cuda() + weights_ = torch.tensor([0.10,1.,0.80,1.]) + ce_loss = nn.CrossEntropyLoss(weight=weights_).cuda() + + best_score = 0 + torch.cuda.empty_cache() + + scaler = amp.GradScaler() + + for epoch in range(100): + train_epoch(epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader) + if epoch % 2 == 0: + torch.cuda.empty_cache() + best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch) + + elapsed = timeit.default_timer() - t0 + torch.cuda.empty_cache() + print('Time: {:.3f} min'.format(elapsed / 60)) + diff --git a/xBD_code/train_unettransformer.py b/xBD_code/train_unettransformer.py new file mode 100644 index 0000000..973f833 --- /dev/null +++ b/xBD_code/train_unettransformer.py @@ -0,0 +1,580 @@ +import os +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" + +from os import path, makedirs, listdir +import sys +import numpy as np +np.random.seed(1) +import random +random.seed(1) + +import torch +from torch import nn +from torch.backends import cudnn +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import torch.optim.lr_scheduler as lr_scheduler +#from apex import amp +import torch.cuda.amp as amp + +from adamw import AdamW +from losses import dice_round, ComboLoss + +from tqdm import tqdm +import timeit +import cv2 + +from zoo.models import BASE_Transformer, Res34_Unet_Double +from zoo.model_transformer_encoding import BASE_Transformer_UNet + +model = "TUNet" + +if model == "TUNet": + print("UNet Transformer") + model = BASE_Transformer_UNet(input_nc=3, output_nc=5, token_len=4, resnet_stages_num=4, + with_pos='learned', enc_depth=1, dec_depth=8).cuda() + snapshot_name = 'BASE_UNet_Transformer_F1update' + snap_to_load = 'res34_loc_0_1_best' +elif model == "BiT": + print("BiT ....") + model = BASE_Transformer(input_nc=3, output_nc=5, token_len=4, resnet_stages_num=4, + with_pos='learned', enc_depth=1, dec_depth=8).cuda() + snapshot_name = 'BiT_F1update' + snap_to_load = 'res34_loc_0_1_best' +else: + model = Res34_Unet_Double().cuda() + snapshot_name = 'Res34_Unet_Double' + snap_to_load = 'res34_loc_0_1_best' + + +from imgaug import augmenters as iaa +from utils import * 
+from skimage.morphology import square, dilation +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score + +import gc +torch.cuda.empty_cache() + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + + + +# train_dirs = ['/scratch/nka77/DATA/AOI3', '/scratch/nka77/DATA/train', '/scratch/nka77/DATA/tier3'] +train_dirs = ['/scratch/nka77/DATA/train', '/scratch/nka77/DATA/tier3'] +models_folder = '/scratch/nka77/xview_first/weights' + +loc_folder = '/scratch/nka77/xview_first/pred/pred34_loc_val' + +input_shape = (1024,1024) +# input_shape = (608, 608) +crop_size = 256 + +all_files = [] +for d in train_dirs: + for f in sorted(listdir(path.join(d, 'images'))): + if ('_pre_disaster.png' in f) and (('hurricane-harvey' in f) | ('hurricane-michael' in f) | ('mexico-earthquake' in f) | ('tuscaloosa-tornado' in f)): + all_files.append(path.join(d, 'images', f)) + +# #aoi_files = [] +# aoi_dir = '/scratch/nka77/DATA/AOI3' +# for f in sorted(listdir(path.join(aoi_dir, 'images'))): +# if ('_pre_disaster.png' in f): +# all_files.append(path.join(aoi_dir, 'images', f)) +# #print(aoi_files) + +class TrainData(Dataset): + def __init__(self, train_idxs): + super().__init__() + self.train_idxs = train_idxs + self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2) + + def __len__(self): + return len(self.train_idxs) + + def __getitem__(self, idx): + _idx = self.train_idxs[idx] + + fn = all_files[_idx] + + img = cv2.imread(fn, cv2.IMREAD_COLOR) + img2 = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR) + + msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED) + lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + + x0 = random.randint(0, img.shape[1] - crop_size) + y0 = random.randint(0, img.shape[0] - crop_size) + + img = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + msk1[lbl_msk1 == 1] = 255 + + if random.random() > 0.5: + img = img[::-1, ...] + img2 = img2[::-1, ...] + msk0 = msk0[::-1, ...] + msk1 = msk1[::-1, ...] + msk2 = msk2[::-1, ...] + msk3 = msk3[::-1, ...] + msk4 = msk4[::-1, ...] 
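+        # Additional geometric augmentation below: a random 90-degree rotation is drawn for
+        # ~95% of crops, a random pixel shift for ~10%, and a small rotate+scale jitter for
+        # ~40%, each applied identically to both acquisition dates and all five masks.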
+ + if random.random() > 0.05: + rot = random.randrange(4) + if rot > 0: + img = np.rot90(img, k=rot) + img2 = np.rot90(img2, k=rot) + msk0 = np.rot90(msk0, k=rot) + msk1 = np.rot90(msk1, k=rot) + msk2 = np.rot90(msk2, k=rot) + msk3 = np.rot90(msk3, k=rot) + msk4 = np.rot90(msk4, k=rot) + + if random.random() > 0.9: + shift_pnt = (random.randint(-320, 320), random.randint(-320, 320)) + img = shift_image(img, shift_pnt) + img2 = shift_image(img2, shift_pnt) + msk0 = shift_image(msk0, shift_pnt) + msk1 = shift_image(msk1, shift_pnt) + msk2 = shift_image(msk2, shift_pnt) + msk3 = shift_image(msk3, shift_pnt) + msk4 = shift_image(msk4, shift_pnt) + + if random.random() > 0.6: + rot_pnt = (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320)) + scale = 0.9 + random.random() * 0.2 + angle = random.randint(0, 20) - 10 + if (angle != 0) or (scale != 1): + img = rotate_image(img, angle, scale, rot_pnt) + img2 = rotate_image(img2, angle, scale, rot_pnt) + msk0 = rotate_image(msk0, angle, scale, rot_pnt) + msk1 = rotate_image(msk1, angle, scale, rot_pnt) + msk2 = rotate_image(msk2, angle, scale, rot_pnt) + msk3 = rotate_image(msk3, angle, scale, rot_pnt) + msk4 = rotate_image(msk4, angle, scale, rot_pnt) + + '''crop_size = input_shape[0] + if random.random() > 0.2: + crop_size = random.randint(int(input_shape[0] / 1.15), int(input_shape[0] / 0.85)) + + bst_x0 = random.randint(0, img.shape[1] - crop_size) + bst_y0 = random.randint(0, img.shape[0] - crop_size) + bst_sc = -1 + try_cnt = random.randint(1, 10) + for i in range(try_cnt): + x0 = random.randint(0, img.shape[1] - crop_size) + y0 = random.randint(0, img.shape[0] - crop_size) + _sc = msk2[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk3[y0:y0+crop_size, x0:x0+crop_size].sum() * 5 + msk4[y0:y0+crop_size, x0:x0+crop_size].sum() * 2 + msk1[y0:y0+crop_size, x0:x0+crop_size].sum() + if _sc > bst_sc: + bst_sc = _sc + bst_x0 = x0 + bst_y0 = y0 + x0 = bst_x0 + y0 = bst_y0 + img = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size] + msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size] + msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size] + msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size] + + + if crop_size != input_shape[0]: + img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR) + msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR) + msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR) + msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR) + msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR) + msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)''' + + + if random.random() > 0.985: + img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5)) + elif random.random() > 0.985: + img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5)) + + if random.random() > 0.985: + img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5)) + elif random.random() > 0.985: + img2 = change_hsv(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5)) + + if random.random() > 0.98: + if random.random() > 0.985: + img = clahe(img) + elif random.random() > 0.985: + img = gauss_noise(img) + elif random.random() > 0.985: + img = 
cv2.blur(img, (3, 3)) + elif random.random() > 0.98: + if random.random() > 0.985: + img = saturation(img, 0.9 + random.random() * 0.2) + elif random.random() > 0.985: + img = brightness(img, 0.9 + random.random() * 0.2) + elif random.random() > 0.985: + img = contrast(img, 0.9 + random.random() * 0.2) + + if random.random() > 0.98: + if random.random() > 0.985: + img2 = clahe(img2) + elif random.random() > 0.985: + img2 = gauss_noise(img2) + elif random.random() > 0.985: + img2 = cv2.blur(img2, (3, 3)) + elif random.random() > 0.98: + if random.random() > 0.985: + img2 = saturation(img2, 0.9 + random.random() * 0.2) + elif random.random() > 0.985: + img2 = brightness(img2, 0.9 + random.random() * 0.2) + elif random.random() > 0.985: + img2 = contrast(img2, 0.9 + random.random() * 0.2) + + + if random.random() > 0.983: + el_det = self.elastic.to_deterministic() + img = el_det.augment_image(img) + + if random.random() > 0.983: + el_det = self.elastic.to_deterministic() + img2 = el_det.augment_image(img2) + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk[..., 0] = False + msk[..., 1] = dilation(msk[..., 1], square(5)) + msk[..., 2] = dilation(msk[..., 2], square(5)) + msk[..., 3] = dilation(msk[..., 3], square(5)) + msk[..., 4] = dilation(msk[..., 4], square(5)) + msk[..., 1][msk[..., 2:].max(axis=2)] = False + msk[..., 3][msk[..., 2]] = False + msk[..., 4][msk[..., 2]] = False + msk[..., 4][msk[..., 3]] = False + msk[..., 0][msk[..., 1:].max(axis=2)] = True + msk = msk * 1 + + lbl_msk = msk.argmax(axis=2) + + img = np.concatenate([img, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.from_numpy(img.transpose((2, 0, 1))).float() + msk = torch.from_numpy(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn} + return sample + + +class ValData(Dataset): + def __init__(self, image_idxs): + super().__init__() + self.image_idxs = image_idxs + + def __len__(self): + return len(self.image_idxs) + + def __getitem__(self, idx): + _idx = self.image_idxs[idx] + + fn = all_files[_idx] + + img = cv2.imread(fn, cv2.IMREAD_COLOR) + img2 = cv2.imread(fn.replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_COLOR) + + # msk_loc = cv2.imread(path.join(loc_folder, '{0}.png'.format(fn.split('/')[-1].replace('.png', '_part1.png'))), cv2.IMREAD_UNCHANGED) > (0.3*255) + + msk0 = cv2.imread(fn.replace('/images/', '/masks/'), cv2.IMREAD_UNCHANGED) + lbl_msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + + # x0 = random.randint(0, img.shape[1] - crop_size) + # y0 = random.randint(0, img.shape[0] - crop_size) + x0 = 512 + y0 = 512 + + img = img[y0:y0+crop_size, x0:x0+crop_size, :] + img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :] + msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size] + lbl_msk1 = lbl_msk1[y0:y0+crop_size, x0:x0+crop_size] + # msk_loc = msk_loc[y0:y0+crop_size, x0:x0+crop_size] + + msk1 = np.zeros_like(lbl_msk1) + msk2 = np.zeros_like(lbl_msk1) + msk3 = np.zeros_like(lbl_msk1) + msk4 = np.zeros_like(lbl_msk1) + msk1[lbl_msk1 == 1] = 255 + msk2[lbl_msk1 == 2] = 255 + msk3[lbl_msk1 == 3] = 255 + msk4[lbl_msk1 == 4] = 255 + + msk0 = msk0[..., np.newaxis] + msk1 = msk1[..., np.newaxis] + msk2 = msk2[..., np.newaxis] + msk3 = msk3[..., np.newaxis] + msk4 = msk4[..., np.newaxis] + + msk = 
np.concatenate([msk0, msk1, msk2, msk3, msk4], axis=2) + msk = (msk > 127) + + msk = msk * 1 + + lbl_msk = msk[..., 1:].argmax(axis=2) + + img = np.concatenate([img, img2], axis=2) + img = preprocess_inputs(img) + + img = torch.from_numpy(img.transpose((2, 0, 1))).float() + msk = torch.from_numpy(msk.transpose((2, 0, 1))).long() + + sample = {'img': img, 'msk': msk, 'lbl_msk': lbl_msk, 'fn': fn, 'msk_loc': msk} + return sample + + +def validate(model, data_loader): + dices0 = [] + + tp = np.zeros((4,)) + fp = np.zeros((4,)) + fn = np.zeros((4,)) + totalp = np.zeros((4,)) + + _thr = 0.3 + data_loader = tqdm(data_loader) + with torch.no_grad(): + for i, sample in enumerate(data_loader): + msks = sample["msk"].numpy() + lbl_msk = sample["lbl_msk"].numpy() + imgs = sample["img"].cuda(non_blocking=True) + # msk_loc = sample["msk_loc"].numpy() * 1 + out = model(imgs) + + # msk_pred = msk_loc + msk_pred = torch.sigmoid(out).cpu().numpy()[:, 0, ...] + msk_damage_pred = torch.sigmoid(out).cpu().numpy()[:, 1:, ...] + + for j in range(msks.shape[0]): + dices0.append(dice(msks[j, 0], msk_pred[j] > _thr)) + targ = lbl_msk[j][lbl_msk[j, 0] > 0] + pred = msk_damage_pred[j].argmax(axis=0) + pred = pred * (msk_pred[j] > _thr) + pred = pred[lbl_msk[j, 0] > 0] + for c in range(4): + tp[c] += np.logical_and(pred == c, targ == c).sum() + fn[c] += np.logical_and(pred != c, targ == c).sum() + fp[c] += np.logical_and(pred == c, targ != c).sum() + totalp += (targ == c).sum() + + d0 = np.mean(dices0) + + f1_sc = np.zeros((4,)) + for c in range(4): + f1_sc[c] = 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) + f1 = 4 / np.sum(1.0 / (f1_sc + 1e-6)) + + f1_sc_wt = np.zeros((4,)) + totalp = totalp/sum(totalp) + for c in range(4): + f1_sc_wt[c] = totalp[c] * 2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) + f1_wt = 1 / np.sum(1.0 / (f1_sc_wt + 1e-6)) + + sc = 0.3 * d0 + 0.7 * f1 + print("Val Score: {}, Dice: {}, F1: {}, F1wt: {}, F1_0: {}, F1_1: {}, F1_2: {}, F1_3: {}".format(sc, d0, f1, f1_wt, f1_sc[0], f1_sc[1], f1_sc[2], f1_sc[3])) + return sc + + +def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch): + model = model.eval() + d = validate(model, data_loader=data_val) + + if d > best_score: + torch.save({ + 'epoch': current_epoch + 1, + 'state_dict': model.state_dict(), + 'best_score': d, + 'optimizer' : optimizer.state_dict(), + }, path.join(models_folder, snapshot_name)) + best_score = d + + print("score: {}\tscore_best: {}".format(d, best_score)) + return best_score + + +def train_epoch(current_epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader): + losses = AverageMeter() + losses1 = AverageMeter() + losses2 = AverageMeter() + dices = AverageMeter() + + iterator = tqdm(train_data_loader) + # iterator = train_data_loader + model.train() + for i, sample in enumerate(iterator): + imgs = sample["img"].cuda(non_blocking=True) + msks = sample["msk"].cuda(non_blocking=True) + lbl_msk = sample["lbl_msk"].cuda(non_blocking=True) + + with amp.autocast(): + out = model(imgs) + + loss0 = seg_loss(out[:, 0, ...], msks[:, 0, ...]) + loss1 = seg_loss(out[:, 1, ...], msks[:, 1, ...]) + loss2 = seg_loss(out[:, 2, ...], msks[:, 2, ...]) + loss3 = seg_loss(out[:, 3, ...], msks[:, 3, ...]) + loss4 = seg_loss(out[:, 4, ...], msks[:, 4, ...]) + + + bldg_mask = (lbl_msk > 0) + for c in range(5): + out[:, c, ...] 
= torch.mul(out[:, c, ...], bldg_mask) + true_bldg = lbl_msk + loss5 = ce_loss(out, true_bldg) + + # max 0.59 on fix val + # loss = 0.01 * loss0 + 0.5 * loss1 + 0.5 * loss2 + 0.5 * loss3 + 1.0 * loss4 + 3 * loss5 + # loss = 0.01 * loss0 + 0.1 * loss1 + 0.8 * loss2 + 0.8 * loss3 + 1 * loss4 + loss5 * 8 + + # To improve the dice... + loss = 0.1 * loss0 + 0.1 * loss1 + 0.6 * loss2 + 0.3 * loss3 + 1 * loss4 + loss5 * 8 + + + with torch.no_grad(): + _probs = torch.sigmoid(out[:, 0, ...]) + dice_sc = 1 - dice_round(_probs, msks[:, 0, ...]) + + losses.update(loss.item(), imgs.size(0)) + losses1.update(loss2.item(), imgs.size(0)) #loss5 + losses2.update(loss3.item(), imgs.size(0)) + + dices.update(dice_sc, imgs.size(0)) + + iterator.set_description( + "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); loss2 {loss1.avg:.4f}; loss3 {loss2.avg:.4f}; Dice {dice.val:.4f} ({dice.avg:.4f})".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, dice=dices)) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + #with amp.scale_loss(loss, optimizer) as scaled_loss: + # scaled_loss.backward() + # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.999) + # optimizer.step() + + scheduler.step(current_epoch) + + print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; loss2 {loss1.avg:.4f}; Dice {dice.avg:.4f}".format( + current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, dice=dices)) + + +if __name__ == '__main__': + t0 = timeit.default_timer() + + makedirs(models_folder, exist_ok=True) + seed = 0 + + cudnn.benchmark = True + batch_size = 16 + val_batch_size = 16 + + file_classes = [] + for fn in tqdm(all_files): + fl = np.zeros((4,), dtype=bool) + msk1 = cv2.imread(fn.replace('/images/', '/masks/').replace('_pre_disaster', '_post_disaster'), cv2.IMREAD_UNCHANGED) + for c in range(1, 5): + fl[c-1] = c in msk1 + file_classes.append(fl) + file_classes = np.asarray(file_classes) + + train_idxs0, val_idxs = train_test_split(np.arange(len(all_files)), test_size=0.1, random_state=seed) + + np.random.seed(seed + 321) + random.seed(seed + 321) + + train_idxs = [] + for i in train_idxs0: + train_idxs.append(i) + if file_classes[i, 1:].max(): + train_idxs.append(i) + if file_classes[i, 3].max(): + train_idxs.append(i) + train_idxs = np.asarray(train_idxs) + steps_per_epoch = len(train_idxs) // batch_size + validation_steps = len(val_idxs) // val_batch_size + + print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps) + + data_train = TrainData(train_idxs) + val_train = ValData(val_idxs) + + train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=6, shuffle=True, pin_memory=False, drop_last=True) + val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=6, shuffle=False, pin_memory=False) + + params = model.parameters() + optimizer = AdamW(params, lr=0.0001, weight_decay=1e-6) + #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") + + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.6) + # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 11, 17, 23, 29, 33, 47, 50, 60, 70, 90, 110, 130, 150, 170, 180, 190], gamma=0.5) + + # snap_to_load = 'res34_loc_{}_1_best'.format(seed) + + print("=> loading checkpoint '{}'".format(snap_to_load)) + checkpoint = torch.load(path.join(models_folder, snap_to_load), map_location='cpu') + 
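    # Warm-start from the checkpoint named by snap_to_load: only entries whose key and
    # tensor shape match the current model are copied, so layers that were added or
    # resized since that run keep their fresh initialisation instead of failing to load.
    # Note that model.load_state_dict(loaded_dict, strict=False) is not a drop-in
    # substitute here, since strict=False still reports shape mismatches as errors.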
loaded_dict = checkpoint['state_dict'] + sd = model.state_dict() + for k in model.state_dict(): + if k in loaded_dict and sd[k].size() == loaded_dict[k].size(): + sd[k] = loaded_dict[k] + loaded_dict = sd + model.load_state_dict(loaded_dict) + optimizer.load_state_dict(checkpoint['optimizer']) + print("loaded checkpoint '{}' (epoch {}, best_score {})" + .format(snap_to_load, checkpoint['epoch'], checkpoint['best_score'])) + del loaded_dict + del sd + del checkpoint + + gc.collect() + torch.cuda.empty_cache() + + model = nn.DataParallel(model).cuda() + + seg_loss = ComboLoss({'dice': 1, 'focal': 8}, per_image=False).cuda() + weights_ = torch.tensor([0.01,0.10,1.,0.80,1.]) + ce_loss = nn.CrossEntropyLoss(weight=weights_).cuda() + + best_score = 0 + torch.cuda.empty_cache() + + scaler = amp.GradScaler() + + for epoch in range(100): + train_epoch(epoch, seg_loss, ce_loss, model, optimizer, scheduler, train_data_loader) + if epoch % 2 == 0: + torch.cuda.empty_cache() + best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch) + + elapsed = timeit.default_timer() - t0 + torch.cuda.empty_cache() + print('Time: {:.3f} min'.format(elapsed / 60)) + diff --git a/xBD_code/utils.py b/xBD_code/utils.py new file mode 100644 index 0000000..3ab7c23 --- /dev/null +++ b/xBD_code/utils.py @@ -0,0 +1,180 @@ +import numpy as np +#import cv2 +import torch +#### Augmentations +def shift_image(img, shift_pnt): + M = np.float32([[1, 0, shift_pnt[0]], [0, 1, shift_pnt[1]]]) + res = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]), borderMode=cv2.BORDER_REFLECT_101) + return res + + +def rotate_image(image, angle, scale, rot_pnt): + rot_mat = cv2.getRotationMatrix2D(rot_pnt, angle, scale) + result = cv2.warpAffine(image, rot_mat, (image.shape[1], image.shape[0]), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) #INTER_NEAREST + return result + + +def gauss_noise(img, var=30): + row, col, ch = img.shape + mean = var + sigma = var**0.5 + gauss = np.random.normal(mean,sigma,(row,col,ch)) + gauss = gauss.reshape(row,col,ch) + gauss = (gauss - np.min(gauss)).astype(np.uint8) + return np.clip(img.astype(np.int32) + gauss, 0, 255).astype('uint8') + + +def clahe(img, clipLimit=2.0, tileGridSize=(5,5)): + img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) + clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize) + img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0]) + img_output = cv2.cvtColor(img_yuv, cv2.COLOR_LAB2RGB) + return img_output + + +def _blend(img1, img2, alpha): + return np.clip(img1 * alpha + (1 - alpha) * img2, 0, 255).astype('uint8') + + +_alpha = np.asarray([0.114, 0.587, 0.299]).reshape((1, 1, 3)) +def _grayscale(img): + return np.sum(_alpha * img, axis=2, keepdims=True) + + +def saturation(img, alpha): + gs = _grayscale(img) + return _blend(img, gs, alpha) + + +def brightness(img, alpha): + gs = np.zeros_like(img) + return _blend(img, gs, alpha) + + +def contrast(img, alpha): + gs = _grayscale(img) + gs = np.repeat(gs.mean(), 3) + return _blend(img, gs, alpha) + + +def change_hsv(img, h, s, v): + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + hsv = hsv.astype(int) + hsv[:,:,0] += h + hsv[:,:,0] = np.clip(hsv[:,:,0], 0, 255) + hsv[:,:,1] += s + hsv[:,:,1] = np.clip(hsv[:,:,1], 0, 255) + hsv[:,:,2] += v + hsv[:,:,2] = np.clip(hsv[:,:,2], 0, 255) + hsv = hsv.astype('uint8') + img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + return img + +def shift_channels(img, b_shift, g_shift, r_shift): + img = img.astype(int) + img[:,:,0] += b_shift + img[:,:,0] = 
np.clip(img[:,:,0], 0, 255) + img[:,:,1] += g_shift + img[:,:,1] = np.clip(img[:,:,1], 0, 255) + img[:,:,2] += r_shift + img[:,:,2] = np.clip(img[:,:,2], 0, 255) + img = img.astype('uint8') + return img + +def invert(img): + return 255 - img + +def channel_shuffle(img): + ch_arr = [0, 1, 2] + np.random.shuffle(ch_arr) + img = img[..., ch_arr] + return img + +####### + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + +def preprocess_inputs(x): + x = np.asarray(x, dtype='float32') + x /= 127 + x -= 1 + return x + +def preprocess_inputs_2img(x): + x = np.asarray(x, dtype='float32') + x[:,:,:6] /= 127 + x[:,:,:6] -= 1 + return x + +def dice(im1, im2, empty_score=1.0): + """ + Computes the Dice coefficient, a measure of set similarity. + Parameters + ---------- + im1 : array-like, bool + Any array of arbitrary size. If not boolean, will be converted. + im2 : array-like, bool + Any other array of identical size. If not boolean, will be converted. + Returns + ------- + dice : float + Dice coefficient as a float on range [0,1]. + Maximum similarity = 1 + No similarity = 0 + Both are empty (sum eq to zero) = empty_score + """ + im1 = np.asarray(im1).astype(np.bool) + im2 = np.asarray(im2).astype(np.bool) + + if im1.shape != im2.shape: + raise ValueError("Shape mismatch: im1 and im2 must have the same shape.") + + im_sum = im1.sum() + im2.sum() + if im_sum == 0: + return empty_score + + # Compute Dice coefficient + intersection = np.logical_and(im1, im2) + + return 2. * intersection.sum() / im_sum + +eps = 1e-6 +def dice_torch(im1, im2, empty_score=1.0): + dice_output = im1 + dice_target = im2 + intersection = torch.sum(torch.logical_and(dice_output, dice_target), dim=0) + union = torch.sum(dice_output) + torch.sum(dice_target, dim=0) + eps + dice_value = ((2 * intersection + eps) / union).mean() + return dice_value + +def iou(im1, im2, empty_score=1.0): + im1 = np.asarray(im1).astype(np.bool) + im2 = np.asarray(im2).astype(np.bool) + + if im1.shape != im2.shape: + raise ValueError("Shape mismatch: im1 and im2 must have the same shape.") + + union = np.logical_or(im1, im2) + im_sum = union.sum() + if im_sum == 0: + return empty_score + + # Compute Dice coefficient + intersection = np.logical_and(im1, im2) + + return intersection.sum() / im_sum diff --git a/xBD_code/zoo/model_transformer.py b/xBD_code/zoo/model_transformer.py new file mode 100644 index 0000000..d822e01 --- /dev/null +++ b/xBD_code/zoo/model_transformer.py @@ -0,0 +1,637 @@ +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.models import resnet34 +import segmentation_models_pytorch as smp +from einops import rearrange + +from importlib.machinery import SourceFileLoader +bitmodule = SourceFileLoader('bitmodule', 'zoo/bit_resnet.py').load_module() + +import matplotlib.pyplot as plt +import random + +class TwoLayerConv2d(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1, bias=False), + nn.BatchNorm2d(in_channels), + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1) + ) + +class 
TwoLayerConv2d_NoBN(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1, bias=False), + # nn.BatchNorm2d(in_channels), + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1) + ) + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class Residual2(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(x, x2, **kwargs) + x + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class PreNorm2(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(self.norm(x), self.norm(x2), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Cross_Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.softmax = softmax + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, m, mask = None): + b, n, _, h = *x.shape, self.heads + q = self.to_q(x) + k = self.to_k(m) + v = self.to_v(m) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + if self.softmax: + attn = dots.softmax(dim=-1) + else: + attn = dots + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + # vis_tmp2(out) + return out + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * 
mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x = ff(x) + return x + +class TransformerDecoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, + dim_head = dim_head, dropout = dropout, + softmax=softmax))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, m, mask = None): + """target(query), memory""" + for attn, ff in self.layers: + x = attn(x, m, mask = mask) + x = ff(x) + return x + + +class ResNet(torch.nn.Module): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18', + output_sigmoid=False, if_upsample_2x=True): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. + """ + super(ResNet, self).__init__() + expand = 1 + if backbone == 'resnet18': + self.resnet = bitmodule.resnet18(pretrained=True, replace_stride_with_dilation=[False,True,True]) + elif backbone == 'resnet34': + self.resnet = bitmodule.resnet34(pretrained=True, replace_stride_with_dilation=[False,True,True]) + else: + raise NotImplementedError + self.relu = nn.ReLU() + self.upsamplex2 = nn.Upsample(scale_factor=2) + self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') + + self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) + + self.resnet_stages_num = resnet_stages_num + + self.if_upsample_2x = if_upsample_2x + if self.resnet_stages_num == 5: + layers = 512 * expand + elif self.resnet_stages_num == 4: + layers = 256 * expand + elif self.resnet_stages_num == 3: + layers = 128 * expand + else: + raise NotImplementedError + self.conv_pred = nn.Conv2d(384, 32, kernel_size=3, padding=1) + self.conv_pred2 = nn.Conv2d(96, 32, kernel_size=3, padding=1) + + + def forward_single(self, x): + # resnet layers + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x = self.resnet.relu(x) + x_2 = self.resnet.maxpool(x) + + x_4 = self.resnet.layer1(x_2) # 1/4, in=64, out=64 + x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 + x_8_pool = self.resnet.maxpool(x_8) + + if self.resnet_stages_num > 3: + x_10 = self.resnet.layer3(x_8_pool) # 1/8, in=128, out=256 + + if self.resnet_stages_num > 4: + raise NotImplementedError + + x = self.upsamplex2(x_10) + x = torch.concat([x, x_8], axis=1) + x = self.conv_pred(x) + x = self.upsamplex2(x) + + x_up2 = torch.concat([x, x_4], axis=1) + x_up2 = self.conv_pred2(x_up2) + + return x, x_up2 + +''' +class BASE_UNet_Transformer(ResNet_Encoder): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, 
dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_UNet_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using BiT Transformer !!!!") + + self.token_len = token_len + self.conv_a32 = nn.Conv2d(32, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_a64 = nn.Conv2d(64, self.token_len, kernel_size=1, padding=0, bias=False) + self.tokenizer = tokenizer + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + self.token_len = self.pooling_size * self.pooling_size + + self.token_trans = token_trans + self.with_decoder = with_decoder + dim = 32 + self.dim = dim + mlp_dim = 2*dim + + self.with_pos = with_pos + if with_pos == 'learned': + self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, dim)) + self.pos_embedding_2 = nn.Parameter(torch.randn(1, self.token_len*2, dim*2)) + decoder_pos_size = 256//4 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder =nn.Parameter(torch.randn(1, dim, + decoder_pos_size, + decoder_pos_size)) + self.pos_embedding_decoder_2 =nn.Parameter(torch.randn(1, dim*2, + decoder_pos_size, + decoder_pos_size)) + self.enc_depth = enc_depth + self.dec_depth = dec_depth + self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, + heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + + self.transformer_2 = Transformer(dim=dim*2, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder_2 = TransformerDecoder(dim=dim*2, depth=self.dec_depth, + heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + + self.conv_remap256 = nn.Conv2d(256, dim, kernel_size=3, padding=1) + self.conv_cat256 = nn.Conv2d(256*2, dim, kernel_size=3, padding=1) + self.conv_remap128 = nn.Conv2d(128, dim, kernel_size=3, padding=1) + self.conv_cat128 = nn.Conv2d(128*2, dim, kernel_size=3, padding=1) + self.conv_remap64 = nn.Conv2d(64, dim*2, kernel_size=3, padding=1) + self.conv_cat64 = nn.Conv2d(64*2, dim*2, kernel_size=3, padding=1) + + self.conv_concat = nn.Conv2d(96, dim, kernel_size=3, padding=1) + + # self.linear = nn.Linear() + + def _forward_semantic_tokens(self, x): + b, c, h, w = x.shape + if c == 32: + spatial_attention = self.conv_a32(x) + else: + spatial_attention = self.conv_a64(x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + return tokens + + def _forward_reshape_tokens(self, x): + # b,c,h,w = x.shape + if self.pool_mode == 'max': + x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) + elif self.pool_mode == 'ave': + x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) + else: + x = x + tokens = rearrange(x, 'b c h w -> b (h w) c') + return 
tokens + + def _forward_transformer(self, x): + if x.shape[2] == self.dim: + if self.with_pos: + x += self.pos_embedding + x = self.transformer(x) + else: + if self.with_pos: + x += self.pos_embedding_2 + x = self.transformer_2(x) + return x + + def _forward_transformer_decoder(self, x, m): + b, c, h, w = x.shape + if h == self.dim: + # x = x + self.pos_embedding_decoder + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder(x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + else: + # x = x + self.pos_embedding_decoder_2 + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder_2(x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_simple_decoder(self, x, m): + b, c, h, w = x.shape + b, l, c = m.shape + m = m.expand([h,w,b,l,c]) + m = rearrange(m, 'h w b l c -> l b c h w') + m = m.sum(0) + x = x + m + return x + + def trans_module(self, x1, x2): + # diff + if x1.shape[1] == 256: + conv_map32 = self.conv_remap256 + conv_cat32 = self.conv_cat256 + elif x1.shape[1] == 128: + conv_map32 = self.conv_remap128 + conv_cat32 = self.conv_cat128 + elif x1.shape[1] == 64: + conv_map32 = self.conv_remap64 + conv_cat32 = self.conv_cat64 + + # x = conv_cat32(torch.concat([x1, x2], axis=1)) + x1 = conv_map32(x1) + x2 = conv_map32(x2) + + # forward tokenzier + if self.tokenizer: + token1 = self._forward_semantic_tokens(x1) + token2 = self._forward_semantic_tokens(x2) + else: + token = self._forward_reshape_tokens(x) + + # forward transformer encoder + if self.token_trans: + self.tokens_ = torch.cat([token1, token2], dim=1) + token = self._forward_transformer(self.tokens_) + # print("after transformer encoding token.shape", x.shape, token.shape) + + # forward transformer decoder + x = conv_map32(torch.abs(x1 - x2)) + if self.with_decoder: + x = self._forward_transformer_decoder(x, token) + else: + x = self._forward_simple_decoder(x, token) + return x + + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + x1, x1_64, x1_64_2, x1_128, x1_256 = self.forward_single(x1) + x2, x2_64, x2_64_2, x2_128, x2_256 = self.forward_single(x2) + + # print(x1.shape, x1_64.shape, x1_64_2.shape, x1_128.shape, x1_256.shape) + x_256 = self.trans_module(x1_256, x2_256) + x_256 = self.upsamplex2(x_256) + + # x_128 = self.trans_module(x1_128, x2_128) + x_64 = self.trans_module(x1_64_2, x1_64_2) + + x = torch.concat([x_64, x_256], axis=1) + x = self.conv_concat(x) + x = self.upsamplex4(x) + + # forward small cnn + x = self.classifier(x) + if self.output_sigmoid: + x = self.sigmoid(x) + + # print("OUTPUT", x.shape) + return x +''' + + + +# unet x_4 as spatial encoding to decoder +# without x_4 upsampling +class BASE_Transformer(ResNet): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using BiT Transformer !!!!") + + self.token_len = token_len + self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_a_2 = nn.Conv2d(32, self.token_len, kernel_size=1, 
padding=0, bias=False) + + self.tokenizer = tokenizer + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + self.token_len = self.pooling_size * self.pooling_size + + self.token_trans = token_trans + self.with_decoder = with_decoder + dim = 32 + mlp_dim = 2*dim + + self.with_pos = with_pos + if with_pos is 'learned': + self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32)) + self.pos_embedding_2 = nn.Parameter(torch.randn(1, self.token_len*2, 32)) + decoder_pos_size = 256//4 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32, + decoder_pos_size, + decoder_pos_size)) + self.pos_embedding_decoder_2 =nn.Parameter(torch.randn(1, 32, + decoder_pos_size, + decoder_pos_size)) + + self.enc_depth = enc_depth + self.dec_depth = dec_depth + self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, + heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + self.transformer_2 = Transformer(dim=dim, depth=self.enc_depth, heads=2, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder_2 = TransformerDecoder(dim=dim, depth=3, + heads=2, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + + def _forward_semantic_tokens(self, x, level=1): + b, c, h, w = x.shape + if level == 1: + spatial_attention = self.conv_a(x) + else: + spatial_attention = self.conv_a_2(x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + return tokens + + def _forward_reshape_tokens(self, x, level=1): + # b,c,h,w = x.shape + if self.pool_mode is 'max': + x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) + elif self.pool_mode is 'ave': + x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) + else: + x = x + tokens = rearrange(x, 'b c h w -> b (h w) c') + return tokens + + def _forward_transformer(self, x, level=1): + if level == 1: + if self.with_pos: + x += self.pos_embedding + x = self.transformer(x) + else: + if self.with_pos: + x += self.pos_embedding_2 + x = self.transformer_2(x) + return x + + def _forward_transformer_decoder(self, x, m, level=1): + b, c, h, w = x.shape + if level == 1: + if self.with_decoder_pos == 'learned': + x = x + self.pos_embedding_decoder + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder(x, m) + else: + if self.with_decoder_pos == 'learned': + x = x + self.pos_embedding_decoder_2 + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder_2(x, m) + + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_simple_decoder(self, x, m, level=1): + b, c, h, w = x.shape + b, l, c = m.shape + m = m.expand([h,w,b,l,c]) + m = rearrange(m, 'h w b l c -> l b c h w') + m = m.sum(0) + x = x + m + return x + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + x1, x1_up2 = self.forward_single(x1) + x2, x2_up2 = self.forward_single(x2) + + # forward 
tokenzier + if self.tokenizer: + token1 = self._forward_semantic_tokens(x1) + token2 = self._forward_semantic_tokens(x2) + else: + token1 = self._forward_reshape_tokens(x1) + token2 = self._forward_reshape_tokens(x2) + # forward transformer encoder + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_) + token1, token2 = self.tokens.chunk(2, dim=1) + # forward transformer decoder + x1 = self._forward_transformer_decoder(x1, token1) + x2 = self._forward_transformer_decoder(x2, token2) + # feature differencing + x_out1 = torch.abs(x1 - x2) + # x_out1 = self.upsamplex2(x) + + # print("1st: after diff and upsample2", x_out1.shape) + + # forward tokenzier + token1 = self._forward_semantic_tokens(x1_up2, level=2) + token2 = self._forward_semantic_tokens(x2_up2, level=2) + # forward transformer encoder + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_, level=2) + token1, token2 = self.tokens.chunk(2, dim=1) + x1 = self._forward_transformer_decoder(x1_up2, token1, level=2) + x2 = self._forward_transformer_decoder(x2_up2, token2, level=2) + + # feature differencing + x_out2 = torch.abs(x1 - x2) + # print("2nd: after diff and upsample2", x_out2.shape) + + x = x_out1 + x_out2 + # x = self.upsamplex2(x) + + x = self.upsamplex4(x) + + # forward small cnn + x = self.classifier(x) + return x diff --git a/xBD_code/zoo/model_transformer_encoding.py b/xBD_code/zoo/model_transformer_encoding.py new file mode 100644 index 0000000..351c696 --- /dev/null +++ b/xBD_code/zoo/model_transformer_encoding.py @@ -0,0 +1,506 @@ +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.models import resnet34 +#import segmentation_models_pytorch as smp +from einops import rearrange + +from importlib.machinery import SourceFileLoader +bitmodule = SourceFileLoader('bitmodule', 'zoo/bit_resnet.py').load_module() +from torchvision.models import efficientnet_b0, resnet18 + +import matplotlib.pyplot as plt +import random + +class TwoLayerConv2d(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1, bias=False), + nn.BatchNorm2d(in_channels), + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1)) + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class Residual2(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(x, x2, **kwargs) + x + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class PreNorm2(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(self.norm(x), self.norm(x2), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Cross_Attention(nn.Module): 
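    # Cross-attention: queries are projected from the input x, while keys and values
    # come from the memory m (the semantic tokens), so each spatial position attends
    # over the token set rather than over the other pixels.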
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.softmax = softmax + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, m, mask = None): + b, n, _, h = *x.shape, self.heads + q = self.to_q(x) + k = self.to_k(m) + v = self.to_v(m) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + if self.softmax: + attn = dots.softmax(dim=-1) + else: + attn = dots + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + # vis_tmp2(out) + return out + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x = ff(x) + return x + +class TransformerDecoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, + dim_head = dim_head, dropout = dropout, + softmax=softmax))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, m, mask = None): + """target(query), memory""" + for attn, ff in self.layers: + x = attn(x, m, mask = mask) + x = ff(x) + return x + + +class ResNet_UNet(torch.nn.Module): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18', + 
output_sigmoid=False, if_upsample_2x=True): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. + """ + super(ResNet_UNet, self).__init__() + expand = 1 + if backbone == 'resnet18': + self.resnet = bitmodule.resnet18(pretrained=True, replace_stride_with_dilation=[False,True,True]) + elif backbone == 'resnet34': + self.resnet = bitmodule.resnet34(pretrained=True, replace_stride_with_dilation=[False,True,True]) + else: + raise NotImplementedError + self.relu = nn.ReLU() + self.upsamplex2 = nn.Upsample(scale_factor=2) + self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') + + self.resnet_stages_num = resnet_stages_num + + self.if_upsample_2x = if_upsample_2x + if self.resnet_stages_num == 5: + layers = 512 * expand + elif self.resnet_stages_num == 4: + layers = 256 * expand + elif self.resnet_stages_num == 3: + layers = 128 * expand + else: + raise NotImplementedError + self.conv_pred = nn.Conv2d(384, 32, kernel_size=3, padding=1) + + def forward_single(self, x): + # resnet layers + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x_2 = self.resnet.relu(x) + x_2_pool = self.resnet.maxpool(x) + + x_4 = self.resnet.layer1(x_2_pool) # 1/4, in=64, out=64 + + x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 + x_8_pool = self.resnet.maxpool(x_8) + + x_10 = self.resnet.layer3(x_8_pool) # 1/8, in=128, out=256 + + if self.resnet_stages_num > 4: + raise NotImplementedError + + # print(x_2.shape, x_4.shape, x_8.shape, x_10.shape) + x = self.upsamplex2(x_10) + + return x_2, x_4, x_8, x_10 + + +# unet x_4 as spatial encoding to decoder +# without x_4 upsampling +class BASE_Transformer_UNet(ResNet_UNet): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos=None, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_Transformer_UNet, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using UNet Transformer !!!!") + + self.token_len = token_len + self.tokenizer = tokenizer + self.token_trans = token_trans + self.with_decoder = with_decoder + self.with_pos = with_pos + + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + self.token_len = self.pooling_size * self.pooling_size + + + # conv squeeze layers before transformer + dim_5, dim_4, dim_3, dim_2 = 32, 32, 32, 32 + self.conv_squeeze_5 = nn.Sequential(nn.Conv2d(256, dim_5, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_4 = nn.Sequential(nn.Conv2d(128, dim_4, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_3 = nn.Sequential(nn.Conv2d(64, dim_3, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_2 = nn.Sequential(nn.Conv2d(64, dim_2, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_layers = nn.ModuleList([self.conv_squeeze_2, self.conv_squeeze_3, self.conv_squeeze_4, self.conv_squeeze_5]) + + self.conv_token_5 = nn.Conv2d(dim_5, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_token_4 = nn.Conv2d(dim_4, self.token_len, kernel_size=1, padding=0, bias=False) + 
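        # each of these 1x1 convs maps the squeezed 32-channel features of one decoder
        # level to token_len spatial attention maps, used by _forward_semantic_tokens
        # to pool the feature map into that many semantic tokens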
self.conv_token_3 = nn.Conv2d(dim_3, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_token_2 = nn.Conv2d(dim_2, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_tokens_layers = nn.ModuleList([self.conv_token_2, self.conv_token_3, self.conv_token_4, self.conv_token_5]) + + + self.conv_decode_5 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_4 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_3 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_2 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_layers = nn.ModuleList([self.conv_decode_2, self.conv_decode_3, self.conv_decode_4, self.conv_decode_5]) + + + if with_pos is 'learned': + self.pos_embedding_5 = nn.Parameter(torch.randn(1, self.token_len*2, dim_5)) + self.pos_embedding_4 = nn.Parameter(torch.randn(1, self.token_len*2, dim_4)) + self.pos_embedding_3 = nn.Parameter(torch.randn(1, self.token_len*2, dim_3)) + #self.pos_embedding_2 = nn.Parameter(torch.randn(1, self.token_len*2, dim_2)) + #self.pos_embedding_layers = nn.ParameterList([self.pos_embedding_2, self.pos_embedding_3, self.pos_embedding_4, self.pos_embedding_5]) + + decoder_pos_size = 256//2 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder_5 =nn.Parameter(torch.randn(1, dim_5, 16, 16)) + self.pos_embedding_decoder_4 =nn.Parameter(torch.randn(1, dim_4, 32, 32)) + self.pos_embedding_decoder_3 =nn.Parameter(torch.randn(1, dim_3, 64, 64)) + #self.pos_embedding_decoder_2 =nn.Parameter(torch.randn(1, dim_2, decoder_pos_size, decoder_pos_size)) + #self.pos_embedding_decoder_layers = nn.ParameterList([self.pos_embedding_decoder_2, self.pos_embedding_decoder_3, self.pos_embedding_decoder_4, self.pos_embedding_decoder_5]) + self.enc_depth = enc_depth + self.dec_depth = dec_depth + self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer_5 = Transformer(dim=dim_5, depth=self.enc_depth, heads=4, + dim_head=self.dim_head, mlp_dim=dim_5, dropout=0) + self.transformer_decoder_5 = TransformerDecoder(dim=dim_5, depth=4, heads=4, + dim_head=self.decoder_dim_head, mlp_dim=dim_5, dropout=0, softmax=decoder_softmax) + self.transformer_4 = Transformer(dim=dim_4, depth=self.enc_depth, heads=4, + dim_head=self.dim_head, mlp_dim=dim_4, dropout=0) + self.transformer_decoder_4 = TransformerDecoder(dim=dim_4, depth=4, heads=4, dim_head=self.decoder_dim_head, + mlp_dim=dim_4, dropout=0, softmax=decoder_softmax) + self.transformer_3 = Transformer(dim=dim_3, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, mlp_dim=dim_3, dropout=0) + self.transformer_decoder_3 = TransformerDecoder(dim=dim_3, depth=8, heads=8, dim_head=self.decoder_dim_head, + mlp_dim=dim_3, dropout=0, softmax=decoder_softmax) + self.transformer_2 = Transformer(dim=dim_2, depth=self.enc_depth, heads=1, + dim_head=32, mlp_dim=dim_2, dropout=0) + self.transformer_decoder_2 = TransformerDecoder(dim=dim_2, depth=1, heads=1, dim_head=32, + mlp_dim=dim_2, dropout=0, softmax=decoder_softmax) + self.transformer_layers = nn.ModuleList([self.transformer_2, self.transformer_3, self.transformer_4, self.transformer_5]) + self.transformer_decoder_layers = nn.ModuleList([self.transformer_decoder_2, self.transformer_decoder_3, self.transformer_decoder_4, self.transformer_decoder_5]) + + self.conv_layer2_0 = TwoLayerConv2d(in_channels=128, out_channels=32, kernel_size=3) + # self.conv_layer2 = 
nn.Conv2d(in_channels=48, out_channels=16, kernel_size=3, padding=1) + # self.conv_layer3 = nn.Conv2d(in_channels=48, out_channels=16, kernel_size=3, padding=1) + # self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=16, kernel_size=3, padding=1) + # self.classifier = nn.Conv2d(in_channels=16, out_channels=output_nc, kernel_size=3, padding=1) + self.conv_layer2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), + nn.ReLU()) + self.conv_layer3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), + nn.ReLU()) + self.conv_layer4 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), + nn.ReLU()) + self.classifier = nn.Conv2d(in_channels=32, out_channels=5, kernel_size=3, padding=1) + + #self.classifier = nn.Conv2d(in_channels=32, out_channels=output_nc, kernel_size=3, padding=1) + #self.seg_head = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1) + #self.cls_head = nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, padding=1) + + + def _forward_semantic_tokens(self, x, layer=None): + b, c, h, w = x.shape + spatial_attention = self.conv_tokens_layers[layer](x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + return tokens + + def _forward_transformer(self, x, layer): + if self.with_pos: + if layer == 5: + x = x + self.pos_embedding_5 + if layer == 4: + x = x + self.pos_embedding_4 + if layer == 3: + x = x + self.pos_embedding_3 + #x += self.pos_embedding_layers[layer] + x = self.transformer_layers[layer](x) + return x + + def _forward_transformer_decoder(self, x, m, layer): + b, c, h, w = x.shape + if self.with_decoder_pos == 'learned': + if layer == 5: + x = x + self.pos_embedding_decoder_5 + if layer == 4: + x = x + self.pos_embedding_decoder_4 + if layer == 3: + x = x + self.pos_embedding_decoder_3 + #x = x + self.pos_embedding_decoder_layers[layer] + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder_layers[layer](x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_trans_module(self, x1, x2, layer): + x1 = self.conv_squeeze_layers[layer](x1) + x2 = self.conv_squeeze_layers[layer](x2) + token1 = self._forward_semantic_tokens(x1, layer) + token2 = self._forward_semantic_tokens(x2, layer) + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_, layer) + token1, token2 = self.tokens.chunk(2, dim=1) + # x1 = self._forward_transformer_decoder(x1, token1, layer) + # x2 = self._forward_transformer_decoder(x2, token2, layer) + # return torch.abs(x1 - x2) + + # V1, V2 + # x1 = self._forward_transformer_decoder(x1, token2, layer) + # x2 = self._forward_transformer_decoder(x2, token1, layer) + # return torch.add(x1, x2) + + # # V3 + diff_token = torch.abs(token2 - token1) + diff_x = self.conv_decode_layers[layer](torch.cat([x1,x2], axis=1)) + x = self._forward_transformer_decoder(diff_x, diff_token, layer) + return x + + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + a_128, a_64, a_32, a_16 = self.forward_single(x1) + b_128, b_64, b_32, b_16 = self.forward_single(x2) + + # level 5 in=256x16x16 out=32x16x16 + x1, x2 = a_16, b_16 + out_5 = self._forward_trans_module(x1, x2, layer=3) + out_5 = self.upsamplex2(out_5) + + # level 4: 
in=128x32x32 out=32x32x32 + x1, x2 = a_32, b_32 + out_4 = self._forward_trans_module(x1, x2, layer=2) + out_4 = out_4 + out_5 + # out_4 = self.conv_layer4(torch.cat([out_4, out_5], axis=1)) + out_4 = self.upsamplex2(out_4) + out_4 = self.conv_layer4(out_4) + + # level 3: in=64x64x64 out=32x64x64 + x1, x2 = a_64, b_64 + out_3 = self._forward_trans_module(x1, x2, layer=1) + out_3 = out_3 + out_4 + # out_3 = self.conv_layer3(torch.cat([out_3, out_4], axis=1)) + out_3 = self.upsamplex2(out_3) + out_3 = self.conv_layer3(out_3) + + # level 2: in=64x128x128 + out_2 = self.conv_layer2_0(torch.cat([a_128, b_128], 1)) + out_2 = out_2 + out_3 + # out_2 = self.conv_layer2(torch.cat([out_2, out_3], axis=1)) + out_2 = self.upsamplex2(out_2) + out_2 = self.conv_layer2(out_2) + # print(out_2.shape, out_3.shape, out_4.shape, out_5.shape) + # forward small cnn + x = self.classifier(out_2) + # x_seg = self.seg_head(out_2) + # x_cls = self.cls_head(out_2) + # x = torch.cat([x_seg, x_cls], axis=1) + return x + + + + +class Discriminator(torch.nn.Module): + def __init__(self, input_nc=5): + super(Discriminator, self).__init__() + self.pre_process = nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3, padding=0) + self.backbone = resnet18(pretrained=True) + + def forward(self, x): + x = self.pre_process(x) + x = self.backbone(x) + return x + + + +class UNet_Loc(ResNet_UNet): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18'): + super(UNet_Loc, self).__init__() + + def forward(self, x): + # forward backbone resnet + a_128, a_64, a_32, a_16 = self.forward_single(x) + + # level 5 in=256x16x16 out=32x16x16 + out_5 = self.upsamplex2(a_16) + + out_4 = self.conv_layer4(torch.cat([out_4, out_5], axis=1)) + out_4 = self.upsamplex2(out_4) + + # level 3: in=64x64x64 out=32x64x64 + x1, x2 = a_64, b_64 + out_3 = self._forward_trans_module(x1, x2, layer=1) + out_3 = out_3 + out_4 + # out_3 = self.conv_layer3(torch.cat([out_3, out_4], axis=1)) + out_3 = self.upsamplex2(out_3) + + # level 2: in=64x128x128 + out_2 = self.conv_layer2_0(torch.cat([a_128, b_128], 1)) + out_2 = out_2 + out_3 + # out_2 = self.conv_layer2(torch.cat([out_2, out_3], axis=1)) + out_2 = self.upsamplex2(out_2) + + # print(out_2.shape, out_3.shape, out_4.shape, out_5.shape) + # forward small cnn + x = self.classifier(out_2) + return x diff --git a/xBD_code/zoo/model_transformer_encoding_BKP_Apr6.py b/xBD_code/zoo/model_transformer_encoding_BKP_Apr6.py new file mode 100644 index 0000000..a89fa69 --- /dev/null +++ b/xBD_code/zoo/model_transformer_encoding_BKP_Apr6.py @@ -0,0 +1,472 @@ +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.models import resnet34 +import segmentation_models_pytorch as smp +from einops import rearrange + +from importlib.machinery import SourceFileLoader +bitmodule = SourceFileLoader('bitmodule', 'zoo/bit_resnet.py').load_module() + +import matplotlib.pyplot as plt +import random + +class TwoLayerConv2d(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1, bias=False), + nn.BatchNorm2d(in_channels), + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1) + ) + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class 
Residual2(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(x, x2, **kwargs) + x + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class PreNorm2(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(self.norm(x), self.norm(x2), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Cross_Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.softmax = softmax + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, m, mask = None): + b, n, _, h = *x.shape, self.heads + q = self.to_q(x) + k = self.to_k(m) + v = self.to_v(m) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + if self.softmax: + attn = dots.softmax(dim=-1) + else: + attn = dots + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + # vis_tmp2(out) + return out + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 
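                # self-attention sub-layer above, position-wise feed-forward sub-layer below;
                # both are wrapped in LayerNorm (PreNorm) and a residual connection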
+ Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x = ff(x) + return x + +class TransformerDecoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, + dim_head = dim_head, dropout = dropout, + softmax=softmax))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, m, mask = None): + """target(query), memory""" + for attn, ff in self.layers: + x = attn(x, m, mask = mask) + x = ff(x) + return x + + +class ResNet_UNet(torch.nn.Module): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18', + output_sigmoid=False, if_upsample_2x=True): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. + """ + super(ResNet_UNet, self).__init__() + expand = 1 + if backbone == 'resnet18': + self.resnet = bitmodule.resnet18(pretrained=True, replace_stride_with_dilation=[False,True,True]) + elif backbone == 'resnet34': + self.resnet = bitmodule.resnet34(pretrained=True, replace_stride_with_dilation=[False,True,True]) + else: + raise NotImplementedError + self.relu = nn.ReLU() + self.upsamplex2 = nn.Upsample(scale_factor=2) + self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') + + self.resnet_stages_num = resnet_stages_num + + self.if_upsample_2x = if_upsample_2x + if self.resnet_stages_num == 5: + layers = 512 * expand + elif self.resnet_stages_num == 4: + layers = 256 * expand + elif self.resnet_stages_num == 3: + layers = 128 * expand + else: + raise NotImplementedError + self.conv_pred = nn.Conv2d(384, 32, kernel_size=3, padding=1) + + def forward_single(self, x): + # resnet layers + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x_2 = self.resnet.relu(x) + x_2_pool = self.resnet.maxpool(x) + + x_4 = self.resnet.layer1(x_2_pool) # 1/4, in=64, out=64 + + x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 + x_8_pool = self.resnet.maxpool(x_8) + + x_10 = self.resnet.layer3(x_8_pool) # 1/8, in=128, out=256 + + if self.resnet_stages_num > 4: + raise NotImplementedError + + # print(x_2.shape, x_4.shape, x_8.shape, x_10.shape) + x = self.upsamplex2(x_10) + + return x_2, x_4, x_8, x_10 + + +# unet x_4 as spatial encoding to decoder +# without x_4 upsampling +class BASE_Transformer_UNet(ResNet_UNet): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_Transformer_UNet, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using UNet Transformer !!!!") + + self.token_len = token_len + self.tokenizer = tokenizer + self.token_trans = token_trans + self.with_decoder = with_decoder + self.with_pos = with_pos + + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + 
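            # without the learned tokenizer, tokens come from pooling the feature map
            # down to pool_size x pool_size, so the token count is overridden to match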
self.token_len = self.pooling_size * self.pooling_size + + + # conv squeeze layers before transformer + dim_5, dim_4, dim_3, dim_2 = 32, 32, 32, 32 + self.conv_squeeze_5 = nn.Sequential(nn.Conv2d(256, dim_5, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_4 = nn.Sequential(nn.Conv2d(128, dim_4, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_3 = nn.Sequential(nn.Conv2d(64, dim_3, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_2 = nn.Sequential(nn.Conv2d(64, dim_2, kernel_size=1, padding=0, bias=False), + nn.ReLU()) + self.conv_squeeze_layers = [self.conv_squeeze_2, self.conv_squeeze_3, self.conv_squeeze_4, self.conv_squeeze_5] + + # 1x1 convs that turn each squeezed feature map into token_len spatial attention maps for semantic-token pooling + self.conv_token_5 = nn.Conv2d(dim_5, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_token_4 = nn.Conv2d(dim_4, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_token_3 = nn.Conv2d(dim_3, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_token_2 = nn.Conv2d(dim_2, self.token_len, kernel_size=1, padding=0, bias=False) + self.conv_tokens_layers = [self.conv_token_2, self.conv_token_3, self.conv_token_4, self.conv_token_5] + + # 3x3 convs that fuse the concatenated bitemporal features (32 + 32 channels) ahead of the transformer decoder + self.conv_decode_5 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_4 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_3 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_2 = nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False) + self.conv_decode_layers = [self.conv_decode_2, self.conv_decode_3, self.conv_decode_4, self.conv_decode_5] + + + if with_pos == 'learned': + self.pos_embedding_5 = nn.Parameter(torch.randn(1, self.token_len*2, dim_5)) + self.pos_embedding_4 = nn.Parameter(torch.randn(1, self.token_len*2, dim_4)) + self.pos_embedding_3 = nn.Parameter(torch.randn(1, self.token_len*2, dim_3)) + self.pos_embedding_2 = nn.Parameter(torch.randn(1, self.token_len*2, dim_2)) + self.pos_embedding_layers = [self.pos_embedding_2, self.pos_embedding_3, self.pos_embedding_4, self.pos_embedding_5] + + decoder_pos_size = 256//4 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder_5 = nn.Parameter(torch.randn(1, dim_5, 16, 16)) + self.pos_embedding_decoder_4 = nn.Parameter(torch.randn(1, dim_4, 32, 32)) + self.pos_embedding_decoder_3 = nn.Parameter(torch.randn(1, dim_3, 64, 64)) + self.pos_embedding_decoder_2 = nn.Parameter(torch.randn(1, dim_2, decoder_pos_size, decoder_pos_size)) + self.pos_embedding_decoder_layers = [self.pos_embedding_decoder_2, self.pos_embedding_decoder_3, self.pos_embedding_decoder_4, self.pos_embedding_decoder_5] + + self.enc_depth = enc_depth + self.dec_depth = dec_depth + self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer_5 = Transformer(dim=dim_5, depth=self.enc_depth, heads=4, + dim_head=self.dim_head, mlp_dim=dim_5, dropout=0) + self.transformer_decoder_5 = TransformerDecoder(dim=dim_5, depth=4, heads=4, + dim_head=self.decoder_dim_head, mlp_dim=dim_5, dropout=0, softmax=decoder_softmax) + self.transformer_4 = Transformer(dim=dim_4, depth=self.enc_depth, heads=4, + dim_head=self.dim_head, mlp_dim=dim_4, dropout=0) + self.transformer_decoder_4 = TransformerDecoder(dim=dim_4, depth=4, heads=4, dim_head=self.decoder_dim_head, + mlp_dim=dim_4, dropout=0, softmax=decoder_softmax) + self.transformer_3 = Transformer(dim=dim_3, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, mlp_dim=dim_3, dropout=0) + 
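# one token-Transformer encoder and one TransformerDecoder per pyramid level; all levels use 32-channel features, but heads and decoder depth differ per level (levels 5/4: 4 heads, depth-4 decoder; level 3: 8 heads, depth-8 decoder; level 2: 1 head, depth-1 decoder) +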
self.transformer_decoder_3 = TransformerDecoder(dim=dim_3, depth=8, heads=8, dim_head=self.decoder_dim_head, + mlp_dim=dim_3, dropout=0, softmax=decoder_softmax) + self.transformer_2 = Transformer(dim=dim_2, depth=self.enc_depth, heads=1, + dim_head=32, mlp_dim=dim_2, dropout=0) + self.transformer_decoder_2 = TransformerDecoder(dim=dim_2, depth=1, heads=1, dim_head=32, + mlp_dim=dim_2, dropout=0, softmax=decoder_softmax) + self.transformer_layers = [self.transformer_2, self.transformer_3, self.transformer_4, self.transformer_5] + self.transformer_decoder_layers = [self.transformer_decoder_2, self.transformer_decoder_3, self.transformer_decoder_4, self.transformer_decoder_5] + + self.conv_layer2_0 = TwoLayerConv2d(in_channels=128, out_channels=32, kernel_size=3) + # self.conv_layer2 = nn.Conv2d(in_channels=48, out_channels=16, kernel_size=3, padding=1) + # self.conv_layer3 = nn.Conv2d(in_channels=48, out_channels=16, kernel_size=3, padding=1) + # self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=16, kernel_size=3, padding=1) + # self.classifier = nn.Conv2d(in_channels=16, out_channels=output_nc, kernel_size=3, padding=1) + + self.classifier = nn.Conv2d(in_channels=32, out_channels=output_nc, kernel_size=3, padding=1) + self.seg_head = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1) + self.cls_head = nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, padding=1) + + + def _forward_semantic_tokens(self, x, layer=None): + b, c, h, w = x.shape + spatial_attention = self.conv_tokens_layers[layer](x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + return tokens + + def _forward_transformer(self, x, layer): + if self.with_pos: + x += self.pos_embedding_layers[layer] + x = self.transformer_layers[layer](x) + return x + + def _forward_transformer_decoder(self, x, m, layer): + b, c, h, w = x.shape + if self.with_decoder_pos == 'learned': + x = x + self.pos_embedding_decoder_layers[layer] + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder_layers[layer](x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_trans_module(self, x1, x2, layer): + x1 = self.conv_squeeze_layers[layer](x1) + x2 = self.conv_squeeze_layers[layer](x2) + token1 = self._forward_semantic_tokens(x1, layer) + token2 = self._forward_semantic_tokens(x2, layer) + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_, layer) + token1, token2 = self.tokens.chunk(2, dim=1) + # x1 = self._forward_transformer_decoder(x1, token1, layer) + # x2 = self._forward_transformer_decoder(x2, token2, layer) + # return torch.abs(x1 - x2) + + # V1, V2 + # x1 = self._forward_transformer_decoder(x1, token2, layer) + # x2 = self._forward_transformer_decoder(x2, token1, layer) + # return torch.add(x1, x2) + + # # V3 + diff_token = torch.abs(token2 - token1) + diff_x = self.conv_decode_layers[layer](torch.cat([x1,x2], axis=1)) + x = self._forward_transformer_decoder(diff_x, diff_token, layer) + return x + + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + a_128, a_64, a_32, a_16 = self.forward_single(x1) + b_128, b_64, b_32, b_16 = self.forward_single(x2) + + # level 5 in=256x16x16 out=32x16x16 + x1, x2 = a_16, b_16 + out_5 = self._forward_trans_module(x1, x2, layer=3) + out_5 = 
self.upsamplex2(out_5) + + # level 4: in=128x32x32 out=32x32x32 + x1, x2 = a_32, b_32 + out_4 = self._forward_trans_module(x1, x2, layer=2) + out_4 = out_4 + out_5 + # out_4 = self.conv_layer4(torch.cat([out_4, out_5], axis=1)) + out_4 = self.upsamplex2(out_4) + + # level 3: in=64x64x64 out=32x64x64 + x1, x2 = a_64, b_64 + out_3 = self._forward_trans_module(x1, x2, layer=1) + out_3 = out_3 + out_4 + # out_3 = self.conv_layer3(torch.cat([out_3, out_4], axis=1)) + out_3 = self.upsamplex2(out_3) + + # level 2: in=64x128x128 + out_2 = self.conv_layer2_0(torch.cat([a_128, b_128], 1)) + out_2 = out_2 + out_3 + # out_2 = self.conv_layer2(torch.cat([out_2, out_3], axis=1)) + out_2 = self.upsamplex2(out_2) + + # print(out_2.shape, out_3.shape, out_4.shape, out_5.shape) + # forward small cnn + # x = self.classifier(out_2) + x_seg = self.seg_head(out_2) + x_cls = self.cls_head(out_2) + x = torch.cat([x_seg, x_cls], axis=1) + return x + + + +class UNet_Loc(ResNet_UNet): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18'): + super(UNet_Loc, self).__init__() + + def forward(self, x): + # forward backbone resnet + a_128, a_64, a_32, a_16 = self.forward_single(x) + + # level 5 in=256x16x16 out=32x16x16 + out_5 = self.upsamplex2(a_16) + + out_4 = self.conv_layer4(torch.cat([out_4, out_5], axis=1)) + out_4 = self.upsamplex2(out_4) + + # level 3: in=64x64x64 out=32x64x64 + x1, x2 = a_64, b_64 + out_3 = self._forward_trans_module(x1, x2, layer=1) + out_3 = out_3 + out_4 + # out_3 = self.conv_layer3(torch.cat([out_3, out_4], axis=1)) + out_3 = self.upsamplex2(out_3) + + # level 2: in=64x128x128 + out_2 = self.conv_layer2_0(torch.cat([a_128, b_128], 1)) + out_2 = out_2 + out_3 + # out_2 = self.conv_layer2(torch.cat([out_2, out_3], axis=1)) + out_2 = self.upsamplex2(out_2) + + # print(out_2.shape, out_3.shape, out_4.shape, out_5.shape) + # forward small cnn + x = self.classifier(out_2) + return x \ No newline at end of file diff --git a/xBD_code/zoo/models.py b/xBD_code/zoo/models.py new file mode 100644 index 0000000..1e9705b --- /dev/null +++ b/xBD_code/zoo/models.py @@ -0,0 +1,2147 @@ +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.models import resnet34 +from .senet import se_resnext50_32x4d, senet154 +from .dpn import dpn92 +#import segmentation_models_pytorch as smp +from einops import rearrange + +from importlib.machinery import SourceFileLoader +bitmodule = SourceFileLoader('bitmodule', 'zoo/bit_resnet.py').load_module() + +class ConvReluBN(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvReluBN, self).__init__() + self.layer = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + def forward(self, x): + return self.layer(x) + + +class ConvRelu(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvRelu, self).__init__() + self.layer = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1), + nn.ReLU(inplace=True) + ) + def forward(self, x): + return self.layer(x) + + +class SCSEModule(nn.Module): + # according to https://arxiv.org/pdf/1808.08127.pdf concat is better + def __init__(self, channels, reduction=16, concat=False): + super(SCSEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, 
kernel_size=1, + padding=0) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + self.spatial_se = nn.Sequential(nn.Conv2d(channels, 1, kernel_size=1, + stride=1, padding=0, bias=False), + nn.Sigmoid()) + self.concat = concat + + def forward(self, x): + module_input = x + + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + chn_se = self.sigmoid(x) + chn_se = chn_se * module_input + + spa_se = self.spatial_se(module_input) + spa_se = module_input * spa_se + if self.concat: + return torch.cat([chn_se, spa_se], dim=1) + else: + return chn_se + spa_se + + +class SeResNext50_Unet_Loc(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(SeResNext50_Unet_Loc, self).__init__() + + encoder_filters = [64, 256, 512, 1024, 2048] + decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2 + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + + self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = se_resnext50_32x4d(pretrained=pretrained) + + # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # _w = encoder.layer0.conv1.state_dict() + # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1) #encoder.layer0.conv1 + self.conv2 = nn.Sequential(encoder.pool, encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + + def forward(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return self.res(dec10) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class SeResNext50_Unet_Double(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(SeResNext50_Unet_Double, self).__init__() + + encoder_filters = [64, 
256, 512, 1024, 2048] + decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2 + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + + self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = se_resnext50_32x4d(pretrained=pretrained) + + # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # _w = encoder.layer0.conv1.state_dict() + # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1) #encoder.layer0.conv1 + self.conv2 = nn.Sequential(encoder.pool, encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + + def forward1(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return dec10 + + + def forward(self, x): + + dec10_0 = self.forward1(x[:, :3, :, :]) + dec10_1 = self.forward1(x[:, 3:, :, :]) + + dec10 = torch.cat([dec10_0, dec10_1], 1) + + return self.res(dec10) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Dpn92_Unet_Loc(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Dpn92_Unet_Loc, self).__init__() + + encoder_filters = [64, 336, 704, 1552, 2688] + decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2 + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = nn.Sequential(ConvRelu(decoder_filters[-1]+encoder_filters[-2], decoder_filters[-1]), SCSEModule(decoder_filters[-1], reduction=16, concat=True)) + self.conv7 = ConvRelu(decoder_filters[-1] * 2, decoder_filters[-2]) + self.conv7_2 = nn.Sequential(ConvRelu(decoder_filters[-2]+encoder_filters[-3], decoder_filters[-2]), SCSEModule(decoder_filters[-2], reduction=16, concat=True)) + self.conv8 = ConvRelu(decoder_filters[-2] * 2, decoder_filters[-3]) + self.conv8_2 = 
nn.Sequential(ConvRelu(decoder_filters[-3]+encoder_filters[-4], decoder_filters[-3]), SCSEModule(decoder_filters[-3], reduction=16, concat=True)) + self.conv9 = ConvRelu(decoder_filters[-3] * 2, decoder_filters[-4]) + self.conv9_2 = nn.Sequential(ConvRelu(decoder_filters[-4]+encoder_filters[-5], decoder_filters[-4]), SCSEModule(decoder_filters[-4], reduction=16, concat=True)) + self.conv10 = ConvRelu(decoder_filters[-4] * 2, decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = dpn92(pretrained=pretrained) + + # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # _w = encoder.blocks['conv1_1'].conv.state_dict() + # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + + self.conv1 = nn.Sequential( + encoder.blocks['conv1_1'].conv, # conv + encoder.blocks['conv1_1'].bn, # bn + encoder.blocks['conv1_1'].act, # relu + ) + self.conv2 = nn.Sequential( + encoder.blocks['conv1_1'].pool, # maxpool + *[b for k, b in encoder.blocks.items() if k.startswith('conv2_')] + ) + self.conv3 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv3_')]) + self.conv4 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv4_')]) + self.conv5 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv5_')]) + + def forward(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + enc1 = (torch.cat(enc1, dim=1) if isinstance(enc1, tuple) else enc1) + enc2 = (torch.cat(enc2, dim=1) if isinstance(enc2, tuple) else enc2) + enc3 = (torch.cat(enc3, dim=1) if isinstance(enc3, tuple) else enc3) + enc4 = (torch.cat(enc4, dim=1) if isinstance(enc4, tuple) else enc4) + enc5 = (torch.cat(enc5, dim=1) if isinstance(enc5, tuple) else enc5) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return self.res(dec10) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Res34_Unet_Single(nn.Module): + def __init__(self, pretrained=True, **kwargs): + super(Res34_Unet_Single, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], 
decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + tmp_conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + self.conv1 = nn.Sequential( + #encoder.conv1, + tmp_conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + def forward(self, x): + batch_size, C, H, W = x.shape + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return self.res(dec10) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Dpn92_Unet_Double(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Dpn92_Unet_Double, self).__init__() + + encoder_filters = [64, 336, 704, 1552, 2688] + decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2 + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = nn.Sequential(ConvRelu(decoder_filters[-1]+encoder_filters[-2], decoder_filters[-1]), SCSEModule(decoder_filters[-1], reduction=16, concat=True)) + self.conv7 = ConvRelu(decoder_filters[-1] * 2, decoder_filters[-2]) + self.conv7_2 = nn.Sequential(ConvRelu(decoder_filters[-2]+encoder_filters[-3], decoder_filters[-2]), SCSEModule(decoder_filters[-2], reduction=16, concat=True)) + self.conv8 = ConvRelu(decoder_filters[-2] * 2, decoder_filters[-3]) + self.conv8_2 = nn.Sequential(ConvRelu(decoder_filters[-3]+encoder_filters[-4], decoder_filters[-3]), SCSEModule(decoder_filters[-3], reduction=16, concat=True)) + self.conv9 = ConvRelu(decoder_filters[-3] * 2, decoder_filters[-4]) + self.conv9_2 = nn.Sequential(ConvRelu(decoder_filters[-4]+encoder_filters[-5], decoder_filters[-4]), SCSEModule(decoder_filters[-4], reduction=16, concat=True)) + self.conv10 = ConvRelu(decoder_filters[-4] * 2, decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = dpn92(pretrained=pretrained) + + # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # _w = encoder.blocks['conv1_1'].conv.state_dict() + # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + + self.conv1 = nn.Sequential( + encoder.blocks['conv1_1'].conv, # conv + 
encoder.blocks['conv1_1'].bn, # bn + encoder.blocks['conv1_1'].act, # relu + ) + self.conv2 = nn.Sequential( + encoder.blocks['conv1_1'].pool, # maxpool + *[b for k, b in encoder.blocks.items() if k.startswith('conv2_')] + ) + self.conv3 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv3_')]) + self.conv4 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv4_')]) + self.conv5 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv5_')]) + + + def forward1(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + enc1 = (torch.cat(enc1, dim=1) if isinstance(enc1, tuple) else enc1) + enc2 = (torch.cat(enc2, dim=1) if isinstance(enc2, tuple) else enc2) + enc3 = (torch.cat(enc3, dim=1) if isinstance(enc3, tuple) else enc3) + enc4 = (torch.cat(enc4, dim=1) if isinstance(enc4, tuple) else enc4) + enc5 = (torch.cat(enc5, dim=1) if isinstance(enc5, tuple) else enc5) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return dec10 + + + def forward(self, x): + + dec10_0 = self.forward1(x[:, :3, :, :]) + dec10_1 = self.forward1(x[:, 3:, :, :]) + + dec10 = torch.cat([dec10_0, dec10_1], 1) + + return self.res(dec10) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Res34_Unet_Loc(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Res34_Unet_Loc, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + def forward(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) 
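+ # ResNet-34 encoder halves the resolution at every stage, e.g. for a 1024x1024 xBD tile: enc1 64x512x512, enc2 64x256x256, enc3 128x128x128, enc4 256x64x64, enc5 512x32x32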
+ enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return self.res(dec10) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Res34_Unet_Double(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Res34_Unet_Double, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0) + #self.res = nn.Conv2d(decoder_filters[-5], 5, 1, stride=1, padding=0) + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv0 = ConvRelu(6,3) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + self.sa = SpatialAttention(kernel_size=3) + + def forward1(self, x): + batch_size, C, H, W = x.shape + #x = self.conv0(x) + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return dec10 + + def forward(self, x): + dec10_0 = self.forward1(x[:, :3, :, :]) + dec10_1 = self.forward1(x[:, 3:, :, :]) + x = torch.cat([dec10_0, dec10_1], 1) + #dec10 = self.sa(dec10) * dec10 + #x = self.forward1(x) + return self.res(x) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = 
nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class SeNet154_Unet_Loc(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(SeNet154_Unet_Loc, self).__init__() + + encoder_filters = [128, 256, 512, 1024, 2048] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = senet154(pretrained=pretrained) + + # conv1_new = nn.Conv2d(9, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + # _w = encoder.layer0.conv1.state_dict() + # _w['weight'] = torch.cat([0.8 * _w['weight'], 0.1 * _w['weight'], 0.1 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1, encoder.layer0.conv2, encoder.layer0.bn2, encoder.layer0.relu2, encoder.layer0.conv3, encoder.layer0.bn3, encoder.layer0.relu3) + self.conv2 = nn.Sequential(encoder.pool, encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + + def forward(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return self.res(dec10) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class SeNet154_Unet_Double(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(SeNet154_Unet_Double, self).__init__() + + encoder_filters = [128, 256, 512, 1024, 2048] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = 
ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = senet154(pretrained=pretrained) + + # conv1_new = nn.Conv2d(9, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + # _w = encoder.layer0.conv1.state_dict() + # _w['weight'] = torch.cat([0.8 * _w['weight'], 0.1 * _w['weight'], 0.1 * _w['weight']], 1) + # conv1_new.load_state_dict(_w) + self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1, encoder.layer0.conv2, encoder.layer0.bn2, encoder.layer0.relu2, encoder.layer0.conv3, encoder.layer0.bn3, encoder.layer0.relu3) + self.conv2 = nn.Sequential(encoder.pool, encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + def forward1(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return dec10 + + def forward(self, x): + + dec10_0 = self.forward1(x[:, :3, :, :]) + dec10_1 = self.forward1(x[:, 3:, :, :]) + + dec10 = torch.cat([dec10_0, dec10_1], 1) + + return self.res(dec10) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + + self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.tanh = nn.Tanh() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.tanh(x) + +class Deeplabv3_Double(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Deeplabv3_Double, self).__init__() + self.loc_output = smp.DeepLabV3Plus(encoder_name='resnet34', encoder_weights=None) + #self.conv1 = nn.Conv1d(in_channels=2, out_channels=5, kernel_size=3) + self.conv2 = ConvRelu(in_channels=2, out_channels=5) + self.sa = SpatialAttention(kernel_size=3) + + + def forward1(self, x): + batch_size, C, H, W = x.shape + x = self.loc_output(x) + #x = self.sa(x) * x + return x + + def forward(self, x): + dec10_0 = self.forward1(x[:, :3, :, :]) + dec10_1 = self.forward1(x[:, 3:, :, :]) + x = 
torch.cat([dec10_0, dec10_1], 1) + #x = self.sa(x) * x + x = self.conv2(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class Res34_Unet_Double_Modified(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(Res34_Unet_Double_Modified, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.conv1d_1 = nn.Conv2d(in_channels=encoder_filters[0]*2, out_channels=encoder_filters[0], kernel_size=1) + self.conv1d_2 = nn.Conv2d(in_channels=encoder_filters[1]*3, out_channels=encoder_filters[1], kernel_size=1) + self.conv1d_3 = nn.Conv2d(in_channels=encoder_filters[2]*3, out_channels=encoder_filters[2], kernel_size=1) + self.conv1d_4 = nn.Conv2d(in_channels=encoder_filters[3]*3, out_channels=encoder_filters[3], kernel_size=1) + self.conv1d_5 = nn.Conv2d(in_channels=encoder_filters[4]*3, out_channels=encoder_filters[4], kernel_size=1) + self.conv1d_6 = nn.Conv2d(in_channels=decoder_filters[-1]*3, out_channels=decoder_filters[-1], kernel_size=1) + self.conv1d_7 = nn.Conv2d(in_channels=decoder_filters[-2]*3, out_channels=decoder_filters[-2], kernel_size=1) + self.conv1d_8 = nn.Conv2d(in_channels=decoder_filters[-3]*3, out_channels=decoder_filters[-3], kernel_size=1) + self.conv1d_9 = nn.Conv2d(in_channels=decoder_filters[-4]*3, out_channels=decoder_filters[-4], kernel_size=1) + + self.res = nn.Conv2d(decoder_filters[-5] * 3, 5, 1, stride=1, padding=0) + + + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + def forward1(self, x): + batch_size, C, H, W = x.shape + + enc1 = self.conv1(x) + enc2 = self.conv2(enc1) + enc3 = self.conv3(enc2) + enc4 = self.conv4(enc3) + enc5 = self.conv5(enc4) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + + return enc1, enc2, enc3, enc4, enc5, dec6, dec7, dec8, dec9, dec10 + + + def 
forward2(self, enc1_1, enc2_1, enc3_1, enc4_1, enc5_1, dec6_1, dec7_1, dec8_1, dec9_1, dec10_1,enc1_2, enc2_2, enc3_2, enc4_2, enc5_2, dec6_2, dec7_2, dec8_2, dec9_2, dec10_2): + enc1 = torch.cat([enc1_1, enc1_2], 1) + enc1 = self.conv1d_1(enc1) + enc2 = self.conv2(enc1) + enc2 = torch.cat([enc2, enc2_1, enc2_2], 1) + enc2 = self.conv1d_2(enc2) + enc3 = self.conv3(enc2) + enc3 = torch.cat([enc3, enc3_1, enc3_2], 1) + enc3 = self.conv1d_3(enc3) + enc4 = self.conv4(enc3) + enc4 = torch.cat([enc4, enc4_1, enc4_2], 1) + enc4 = self.conv1d_4(enc4) + enc5 = self.conv5(enc4) + enc5 = torch.cat([enc5, enc5_1, enc5_2], 1) + enc5 = self.conv1d_5(enc5) + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + dec6 = torch.cat([dec6, dec6_1, dec6_2], 1) + dec6 = self.conv1d_6(dec6) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + dec7 = torch.cat([dec7, dec7_1, dec7_2], 1) + dec7 = self.conv1d_7(dec7) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + dec8 = torch.cat([dec8, dec8_1, dec8_2], 1) + dec8 = self.conv1d_8(dec8) + + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + dec9 = torch.cat([dec9, dec9_1, dec9_2], 1) + dec9 = self.conv1d_9(dec9) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + dec10 = torch.cat([dec10, dec10_1, dec10_2], 1) + return dec10 + + + def forward(self, x): + enc1_1, enc2_1, enc3_1, enc4_1, enc5_1, dec6_1, dec7_1, dec8_1, dec9_1, dec10_1 = self.forward1(x[:, :3, :, :]) + enc1_2, enc2_2, enc3_2, enc4_2, enc5_2, dec6_2, dec7_2, dec8_2, dec9_2, dec10_2 = self.forward1(x[:, 3:, :, :]) + dec10 = self.forward2(enc1_1, enc2_1, enc3_1, enc4_1, enc5_1, dec6_1, dec7_1, dec8_1, dec9_1, dec10_1,enc1_2, enc2_2, enc3_2, enc4_2, enc5_2, dec6_2, dec7_2, dec8_2, dec9_2, dec10_2) + # dec10 = torch.cat([dec10_1, dec10_2], 1) + return self.res(dec10) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class ChannelAttention(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ChannelAttention, self).__init__() + padding = 3 if kernel_size == 7 else 1 + self.conv1 = nn.Conv2d(in_channels*2, out_channels, kernel_size, padding=padding, bias=False) + self.relu = nn.ReLU() +# changed tanh to relu for two_transformer exp + def forward(self, x_1, x_2): + x = torch.cat([x_1, x_2], dim=1) + x = self.conv1(x) + return self.relu(x) + +class ChannelAttention_OnBottle(nn.Module): + def __init__(self, in_planes, ratio=16, att_type='max'): + super(ChannelAttention_OnBottle, self).__init__() + self.att_type = att_type + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + # self.min_pool = nn.AdaptiveMinPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) + self.sigmoid = nn.Sigmoid() + self.fc3 = nn.Linear(64, 512) + self.fc4 = nn.Linear(96, 512) + + def forward(self, x): + if self.att_type == 'max': + out = self.fc2(self.relu(self.fc1(self.max_pool(x)))) + elif self.att_type == 
'max_avg': + max_out = self.relu(self.fc1(self.max_pool(x))) + avg_out = self.relu(self.fc1(self.avg_pool(x))) + out = torch.cat([max_out, avg_out], 1).squeeze() + out = self.fc3(out).unsqueeze(0).unsqueeze(2).unsqueeze(3) + elif self.att_type == 'avg_max_min': + avg_out = self.relu(self.fc1(self.avg_pool(x))) + # min_out = self.relu(self.fc1(self.min_pool(x))) + max_out = self.relu(self.fc1(self.max_pool(x))) + out = torch.cat([avg_out, min_out, max_out], 1).squeeze() + out = self.fc4(out).unsqueeze(0).unsqueeze(2).unsqueeze(3) + return self.relu(out) + +''' +class Attention_block(nn.Module): + def __init__(self,F_g,F_l,F_int): + super(Attention_block,self).__init__() + self.W_g = nn.Sequential( + nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + self.W_x = nn.Sequential( + nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(F_int) + ) + self.psi = nn.Sequential( + nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), + nn.BatchNorm2d(1), + nn.Sigmoid() + ) + self.relu = nn.ReLU(inplace=True) + + def forward(self,g,x): + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1+x1) + psi = self.psi(psi) + return x*psi + +# Attention for bottleneck layers +class ChannelAttention_OnBottle(nn.Module): + def __init__(self, in_planes, ratio=16, att_type='max'): + super(ChannelAttention_OnBottle, self).__init__() + self.att_type = att_type + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + # self.min_pool = nn.AdaptiveMinPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) + self.sigmoid = nn.Sigmoid() + self.fc3 = nn.Linear(64, 512) + self.fc4 = nn.Linear(96, 512) + + def forward(self, x): + if self.att_type == 'max': + out = self.fc2(self.relu(self.fc1(self.max_pool(x)))) + elif self.att_type == 'max_avg': + max_out = self.relu(self.fc1(self.max_pool(x))) + avg_out = self.relu(self.fc1(self.avg_pool(x))) + out = torch.cat([max_out, avg_out], 1).squeeze() + out = self.fc3(out).unsqueeze(0).unsqueeze(2).unsqueeze(3) + elif self.att_type == 'avg_max_min': + avg_out = self.relu(self.fc1(self.avg_pool(x))) + # min_out = self.relu(self.fc1(self.min_pool(x))) + max_out = self.relu(self.fc1(self.max_pool(x))) + out = torch.cat([avg_out, min_out, max_out], 1).squeeze() + out = self.fc4(out).unsqueeze(0).unsqueeze(2).unsqueeze(3) + return self.relu(out) + +class UNet_Change_Transformer(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(UNet_Change_Transformer, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.encoder_filters = [64, 64, 128, 256, 512] + self.decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], 
decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + self.ca_skip_5 = ChannelAttention(encoder_filters[-1], encoder_filters[-1]) + self.ca_skip_4 = ChannelAttention(encoder_filters[-2], encoder_filters[-2]) + self.ca_skip_3 = ChannelAttention(encoder_filters[-3], encoder_filters[-3]) + self.ca_skip_2 = ChannelAttention(encoder_filters[-4], encoder_filters[-4]) + self.ca_skip_1 = ChannelAttention(encoder_filters[-5], encoder_filters[-5]) + + self.ca_bottle_max = ChannelAttention_OnBottle(512, att_type='max') + self.ca_bottle_avg_min = ChannelAttention_OnBottle(512, att_type='max_avg') + + + in_dim = 1024 + mlp_dim = 2*in_dim + enc_depth = 2 + dim_head = 64 + self.transformer = Transformer(dim=in_dim, depth=2, heads=8, + dim_head=dim_head, + mlp_dim=mlp_dim, dropout=0) + + + + + def forward(self, x): + + # Encoder 1 + x_1 = x[:, :3, :, :] + enc1_1 = self.conv1(x_1) + enc2_1 = self.conv2(enc1_1) + enc3_1 = self.conv3(enc2_1) + enc4_1 = self.conv4(enc3_1) + enc5_1 = self.conv5(enc4_1) + + # Encoder 2 + x_2 = x[:, 3:, :, :] + enc1_2 = self.conv1(x_2) + enc2_2 = self.conv2(enc1_2) + enc3_2 = self.conv3(enc2_2) + enc4_2 = self.conv4(enc3_2) + enc5_2 = self.conv5(enc4_2) + + # Bottleneck + # enc5_1 = (self.ca_bottle_max(enc5_1)*enc5_1) + # enc5_2 = (self.ca_bottle_max(enc5_2)*enc5_2) + + # try dot product as attention?? + # enc5 = torch.einsum('bcij,bcij->bcij' , enc5_1, enc5_2) + # enc5 = self.ca_bottle_avg_min(enc5)*enc5 + + enc5 = self.ca_skip_5(enc5_1,enc5_2) + + # B_, C_, H_, W_ = enc5.shape + # enc5_i = enc5.view([B_, C_, H_*W_]) + # enc5_i = self.transformer(enc5_i) + # # print(enc5.shape) + # enc5_i = enc5_i.view([B_, C_, H_, W_]) + # enc5 = self.ca_skip_5(enc5_i,enc5) + + # Decoder + enc4 = self.ca_skip_4(enc4_1, enc4_2) + # enc4 = attention_block(enc4_1, enc4_2, self.encoder_filters[-2]) + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4 + ], 1)) + + enc3 = self.ca_skip_3(enc3_1, enc3_2) + # enc3 = attention_block(enc3_1, enc3_2, self.encoder_filters[-3]) + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3 + ], 1)) + + enc2 = self.ca_skip_2(enc2_1, enc2_2) + # enc2 = attention_block(enc2_1, enc2_2, self.encoder_filters[-4]) + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2 + ], 1)) + + enc1 = self.ca_skip_2(enc1_1, enc1_2) + # enc1 = attention_block(enc1_1, enc1_2, self.encoder_filters[-5]) + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, + enc1 + ], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + out = self.res(dec10) + return out + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() +''' + +class UNet_Change_Transformer_BiT(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(UNet_Change_Transformer_BiT, 
self).__init__() + + print("using UNet_Change_Transformer !!!!") + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 128, 320]) + + self.encoder_filters = [64, 64, 128, 256, 512] + self.decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2] , decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3] , decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4] , decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5] , decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 5, 1, stride=1, padding=0) + + + self.conv5_2 = nn.Conv2d(1024, 512, kernel_size=5, stride=1, padding=2) + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + self.ca_skip_5 = ChannelAttention(encoder_filters[-1], encoder_filters[-1]) + self.ca_skip_4 = ChannelAttention(encoder_filters[-2], encoder_filters[-2]) + self.ca_skip_3 = ChannelAttention(encoder_filters[-3], encoder_filters[-3]) + self.ca_skip_2 = ChannelAttention(encoder_filters[-4], encoder_filters[-4]) + self.ca_skip_1 = ChannelAttention(encoder_filters[-5], encoder_filters[-5]) + + self.ca_bottle_max = ChannelAttention_OnBottle(512, att_type='max') + self.sigmoid = nn.Sigmoid() + + dim = 64 + dim2 = 256 + dim3 = 1024 + heads = 4 + # self.dec_depth = 3 + self.enc_depth = 8 + decoder_softmax = True + self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=heads, + dim_head=dim, mlp_dim=dim**2, dropout=0.05) + # self.transformer2 = Transformer(dim=dim2, depth=self.enc_depth, heads=heads, + # dim_head=dim2, mlp_dim=dim2, dropout=0.05) + # self.transformer3 = Transformer(dim=dim3, depth=self.enc_depth, heads=heads//2, + # dim_head=dim3, mlp_dim=dim3, dropout=0.05) + # self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, + # heads=heads, dim_head=dim, mlp_dim=dim**2, dropout=0, + # softmax=decoder_softmax) + # self.transformer_decoder2 = TransformerDecoder(dim=dim2, depth=self.dec_depth, + # heads=heads, dim_head=dim2, mlp_dim=dim2, dropout=0, + # softmax=decoder_softmax) + # self.transformer_decoder3 = TransformerDecoder(dim=dim3, depth=self.dec_depth, + # heads=heads//2, dim_head=dim3, mlp_dim=dim3//2, dropout=0, + # softmax=decoder_softmax) + + + def forward(self, x): + # Encoder 1 + x_1 = x[:, :3, :, :] + enc1_1 = self.conv1(x_1) + enc2_1 = self.conv2(enc1_1) + enc3_1 = self.conv3(enc2_1) + enc4_1 = self.conv4(enc3_1) + enc5_1 = self.conv5(enc4_1) + + # Encoder 2 + x_2 = x[:, 3:, :, :] + enc1_2 = self.conv1(x_2) + enc2_2 = self.conv2(enc1_2) + enc3_2 = self.conv3(enc2_2) + enc4_2 = self.conv4(enc3_2) + enc5_2 = self.conv5(enc4_2) + + # # Bottleneck + enc5_1 = (self.ca_bottle_max(enc5_1)*enc5_1) + enc5_2 = (self.ca_bottle_max(enc5_2)*enc5_2) + + enc5 = self.ca_skip_5(enc5_1,enc5_2) + B_, C_, H_, W_ = enc5.shape + enc5_i = enc5.view([B_, 
C_, H_*W_]) + enc5_i = self.transformer(enc5_i) + enc5_i = enc5_i.view([B_, C_, H_, W_]) + enc5 = self.ca_skip_5(enc5_i,enc5) + + enc4 = self.ca_skip_4(enc4_1, enc4_2) + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4], 1)) + # dec6 = self.conv6_2(torch.cat([dec6, enc4_1, enc4_2], 1)) + + enc3 = self.ca_skip_3(enc3_1, enc3_2) + # B_, C_, H_, W_ = enc3.shape + # enc3_i = enc3.view([B_, C_, H_*W_]) + # enc3_i = self.transformer3(enc3_i) + # # enc4_i = enc4_i.view([B_, C_, H_, W_]) + # # enc4 = self.ca_skip_4(enc4_i,enc4) + # enc3 = enc3.view([B_, C_, H_*W_]) + # enc3 = self.transformer_decoder3(enc3_i, enc3) + # enc3 = enc3.view([B_, C_, H_, W_]) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3], 1)) + + enc2 = self.ca_skip_2(enc2_1, enc2_2) + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2], 1)) + + enc1 = self.ca_skip_2(enc1_1, enc1_2) + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, enc1], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + out = self.res(dec10) + + return out + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class ResNet(torch.nn.Module): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18', + output_sigmoid=False, if_upsample_2x=True): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. 
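+        In practice the backbone is a dilated ResNet-18/34 (output stride 8); conv_pred projects its last stage to 32 channels, and the forward pass upsamples and classifies the absolute difference of the two temporal feature maps.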
+ """ + super(ResNet, self).__init__() + expand = 1 + if backbone == 'resnet18': + self.resnet = bitmodule.resnet18(pretrained=True, replace_stride_with_dilation=[False,True,True]) + elif backbone == 'resnet34': + self.resnet = bitmodule.resnet34(pretrained=True, replace_stride_with_dilation=[False,True,True]) + else: + raise NotImplementedError + self.relu = nn.ReLU() + self.upsamplex2 = nn.Upsample(scale_factor=2) + self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') + + self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) + + self.resnet_stages_num = resnet_stages_num + + self.if_upsample_2x = if_upsample_2x + if self.resnet_stages_num == 5: + layers = 512 * expand + elif self.resnet_stages_num == 4: + layers = 256 * expand + elif self.resnet_stages_num == 3: + layers = 128 * expand + else: + raise NotImplementedError + self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1) + + self.output_sigmoid = output_sigmoid + self.sigmoid = nn.Sigmoid() + + def forward(self, x1, x2): + x1 = self.forward_single(x1) + x2 = self.forward_single(x2) + x = torch.abs(x1 - x2) + if not self.if_upsample_2x: + x = self.upsamplex2(x) + x = self.upsamplex4(x) + x = self.classifier(x) + + if self.output_sigmoid: + x = self.sigmoid(x) + return x + + def forward_single(self, x): + # resnet layers + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x = self.resnet.relu(x) + x = self.resnet.maxpool(x) + + x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64 + x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 + + if self.resnet_stages_num > 3: + x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256 + + if self.resnet_stages_num == 5: + x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512 + elif self.resnet_stages_num > 5: + raise NotImplementedError + + if self.if_upsample_2x: + x = self.upsamplex2(x_8) + else: + x = x_8 + # output layers + x = self.conv_pred(x) + return x + +class BASE_Transformer(ResNet): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using BiT Transformer !!!!") + + self.token_len = token_len + self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, + padding=0, bias=False) + self.tokenizer = tokenizer + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + self.token_len = self.pooling_size * self.pooling_size + + self.token_trans = token_trans + self.with_decoder = with_decoder + dim = 32 + mlp_dim = 2*dim + + self.with_pos = with_pos + if with_pos is 'learned': + self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32)) + decoder_pos_size = 256//4 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32, + decoder_pos_size, + decoder_pos_size)) + self.enc_depth = enc_depth + self.dec_depth = dec_depth + self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer = 
Transformer(dim=dim, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, + heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + + def _forward_semantic_tokens(self, x): + b, c, h, w = x.shape + spatial_attention = self.conv_a(x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + + return tokens + + def _forward_reshape_tokens(self, x): + # b,c,h,w = x.shape + if self.pool_mode is 'max': + x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) + elif self.pool_mode is 'ave': + x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) + else: + x = x + tokens = rearrange(x, 'b c h w -> b (h w) c') + return tokens + + def _forward_transformer(self, x): + if self.with_pos: + x += self.pos_embedding + x = self.transformer(x) + return x + + def _forward_transformer_decoder(self, x, m): + b, c, h, w = x.shape + if self.with_decoder_pos == 'fix': + x = x + self.pos_embedding_decoder + elif self.with_decoder_pos == 'learned': + x = x + self.pos_embedding_decoder + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder(x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_simple_decoder(self, x, m): + b, c, h, w = x.shape + b, l, c = m.shape + m = m.expand([h,w,b,l,c]) + m = rearrange(m, 'h w b l c -> l b c h w') + m = m.sum(0) + x = x + m + return x + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + x1 = self.forward_single(x1) + x2 = self.forward_single(x2) + # forward tokenzier + if self.tokenizer: + token1 = self._forward_semantic_tokens(x1) + token2 = self._forward_semantic_tokens(x2) + else: + token1 = self._forward_reshape_tokens(x1) + token2 = self._forward_reshape_tokens(x2) + # forward transformer encoder + if self.token_trans: + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_) + token1, token2 = self.tokens.chunk(2, dim=1) + + # forward transformer decoder + if self.with_decoder: + x1 = self._forward_transformer_decoder(x1, token1) + x2 = self._forward_transformer_decoder(x2, token2) + else: + x1 = self._forward_simple_decoder(x1, token1) + x2 = self._forward_simple_decoder(x2, token2) + # feature differencing + x = torch.abs(x1 - x2) + if not self.if_upsample_2x: + x = self.upsamplex2(x) + x = self.upsamplex4(x) + # forward small cnn + x = self.classifier(x) + if self.output_sigmoid: + x = self.sigmoid(x) + return x + +class TwoLayerConv2d(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1, bias=False), + nn.BatchNorm2d(in_channels), + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + padding=kernel_size // 2, stride=1) + ) + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class Residual2(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(x, x2, **kwargs) + x + +class PreNorm(nn.Module): + def __init__(self, 
dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class PreNorm2(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, x2, **kwargs): + return self.fn(self.norm(x), self.norm(x2), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Cross_Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.softmax = softmax + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, m, mask = None): + b, n, _, h = *x.shape, self.heads + q = self.to_q(x) + k = self.to_k(m) + v = self.to_v(m) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + if self.softmax: + attn = dots.softmax(dim=-1) + else: + attn = dots + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + # vis_tmp2(out) + return out + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value = True) + assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = torch.einsum('bhij,bhjd->bhid', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x = ff(x) + return x + +class 
TransformerDecoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, + dim_head = dim_head, dropout = dropout, + softmax=softmax))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) + ])) + def forward(self, x, m, mask = None): + """target(query), memory""" + for attn, ff in self.layers: + x = attn(x, m, mask = mask) + x = ff(x) + return x + + +class UNet_Change_Two_Transformer(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(UNet_Change_Two_Transformer, self).__init__() + + encoder_filters = [64, 64, 128, 256, 512] + decoder_filters = np.asarray([48, 64, 96, 128, 320]) + + self.encoder_filters = [64, 64, 128, 256, 512] + self.decoder_filters = np.asarray([48, 64, 96, 160, 320]) + + self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1]) + self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2]*2, decoder_filters[-1]) + self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2]) + self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3]*2, decoder_filters[-2]) + self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3]) + self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4] , decoder_filters[-3]) + self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4]) + self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5]*2 , decoder_filters[-4]) + self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5]) + + self.res = nn.Conv2d(decoder_filters[-5], 5, 1, stride=1, padding=0) + + self._initialize_weights() + + encoder = resnet34(pretrained=pretrained) + self.conv1 = nn.Sequential( + encoder.conv1, + encoder.bn1, + encoder.relu) + self.conv2 = nn.Sequential( + encoder.maxpool, + encoder.layer1) + self.conv3 = encoder.layer2 + self.conv4 = encoder.layer3 + self.conv5 = encoder.layer4 + + self.ca_skip_5 = ChannelAttention(encoder_filters[-1], encoder_filters[-1]) + self.ca_skip_4 = ChannelAttention(encoder_filters[-2], encoder_filters[-2]) + self.ca_skip_3 = ChannelAttention(encoder_filters[-3], encoder_filters[-3]) + self.ca_skip_2 = ChannelAttention(encoder_filters[-4], encoder_filters[-4]) + self.ca_skip_1 = ChannelAttention(encoder_filters[-5], encoder_filters[-5]) + + self.ca_bottle_max = ChannelAttention_OnBottle(512, att_type='max') + self.ca_bottle_avg_min = ChannelAttention_OnBottle(512, att_type='max_avg') + self.sigmoid = nn.Sigmoid() + self.linearb = nn.Linear(1024, 512) + + dim = 64 + dim2 = 4096 + mlp_dim = 2*dim + enc_depth = 2 + dim_head = 64 + decoder_dim_head = 64 + decoder_softmax = True + self.transformer = Transformer(dim=dim, depth=6, heads=4, + dim_head=dim_head, + mlp_dim=mlp_dim, dropout=0) + + self.transformer2 = Transformer(dim=dim2, depth=4, heads=2, # depth=6, heads=4 leads to GPU outage + dim_head=dim_head, + mlp_dim=dim2, dropout=0) + + self.classifier = TwoLayerConv2d(in_channels=512, out_channels=2) + self.convT = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1) + + + def forward(self, x): + # Encoder 1 + x_1 = x[:, :3, :, :] + enc1_1 = self.conv1(x_1) + enc2_1 = self.conv2(enc1_1) + enc3_1 = self.conv3(enc2_1) + enc4_1 = self.conv4(enc3_1) + enc5_1 = self.conv5(enc4_1) + + # Encoder 2 + x_2 = x[:, 3:, :, :] + enc1_2 = self.conv1(x_2) + enc2_2 = self.conv2(enc1_2) + enc3_2 = self.conv3(enc2_2) + enc4_2 = self.conv4(enc3_2) + 
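# NOTE: Encoder 2 reuses self.conv1..self.conv5, so the pre- and post-disaster images are embedded by the same (shared-weight, Siamese) ResNet-34 layers. +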
enc5_2 = self.conv5(enc4_2) + + # Bottleneck + # enc5_1 = (self.ca_bottle_max(enc5_1)*enc5_1) + # enc5_2 = (self.ca_bottle_max(enc5_2)*enc5_2) + # enc5_c = self.ca_skip_5(enc5_1,enc5_2) + + ## run 1: updating channel attention + enc5 = self.ca_skip_5(enc5_1,enc5_2) + B_, C_, H_, W_ = enc5.shape + enc5_i = enc5.view([B_, C_, H_*W_]).contiguous() + + # enc5_diff = (enc5_1 - enc5_2) + # spatial_attention = enc5_diff.view([B_, C_, H_*W_]).contiguous() + # spatial_attention = torch.softmax(spatial_attention, dim=-1) + + # enc5_i = torch.einsum('bln,bln->bln', spatial_attention, enc5_i) + enc5_t = self.transformer(enc5_i) + enc5_t = enc5_t.view([B_, C_, H_, W_]).contiguous() + enc5 = self.ca_skip_5(enc5_t,enc5) + + # Decoder + # enc4 = self.ca_skip_4(enc4_1, enc4_2) + # B_, C_, H_, W_ = enc4.shape + # enc4 = enc4.view([B_, C_, H_*W_]).contiguous() + + # enc4_diff = (enc4_1 - enc4_2) + # spatial_attention = enc4_diff.view([B_, C_, H_*W_]).contiguous() + # spatial_attention = torch.softmax(spatial_attention, dim=-1) + + # enc4 = torch.einsum('bln,bln->bln', spatial_attention, enc4) + # enc4 = enc4.view([B_, C_, H_,W_]).contiguous() + + dec6 = self.conv6(F.interpolate(enc5, scale_factor=2)) + dec6 = self.conv6_2(torch.cat([dec6, enc4_1, enc4_2], 1)) + + ## run4: depth=2, heads=2 + # enc3 = self.ca_skip_3(enc3_1, enc3_2) + # B_, C_, H_, W_ = enc3.shape + # enc3_i = enc3.view([B_, C_, H_*W_]).contiguous() + + # # enc3_diff = (enc3_1 - enc3_2) + # # spatial_attention = enc3_diff.view([B_, C_, H_*W_]).contiguous() + # # spatial_attention = torch.softmax(spatial_attention, dim=-1) + + # # enc3 = torch.einsum('bln,bln->bln', spatial_attention, enc3) + # enc3_t = self.transformer2(enc3_i) + # enc3_t = enc3_t.view([B_, C_, H_, W_]).contiguous() + # enc3 = self.ca_skip_3(enc3, enc3_t) + + dec7 = self.conv7(F.interpolate(dec6, scale_factor=2)) + dec7 = self.conv7_2(torch.cat([dec7, enc3_1, enc3_2], 1)) + + ## run0: depth=2, heads=1 + enc2 = self.ca_skip_2(enc2_1, enc2_2) + B_, C_, H_, W_ = enc2.shape + enc2_i = enc2.view([B_, C_, H_*W_]).contiguous() + + # enc2_diff = (enc2_1 - enc2_2) + # spatial_attention = enc2_diff.view([B_, C_, H_*W_]).contiguous() + # spatial_attention = torch.softmax(spatial_attention, dim=-1) + + # enc2_i = torch.einsum('bln,bln->bln', spatial_attention, enc2_i) + enc2_t = self.transformer2(enc2_i) + enc2_t = enc2_t.view([B_, C_, H_,W_]).contiguous() + enc2 = self.ca_skip_2(enc2_t, enc2) + + dec8 = self.conv8(F.interpolate(dec7, scale_factor=2)) + dec8 = self.conv8_2(torch.cat([dec8, enc2], 1)) + + # enc1 = self.ca_skip_2(enc1_1, enc1_2) + dec9 = self.conv9(F.interpolate(dec8, scale_factor=2)) + dec9 = self.conv9_2(torch.cat([dec9, enc1_1, enc1_2], 1)) + + dec10 = self.conv10(F.interpolate(dec9, scale_factor=2)) + out = self.res(dec10) + + return out #, interim_out + + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + + +class ResNet_Encoder(torch.nn.Module): + def __init__(self, input_nc, output_nc, + resnet_stages_num=5, backbone='resnet18', + output_sigmoid=False, if_upsample_2x=True): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. 
+ """ + super(ResNet, self).__init__() + expand = 1 + if backbone == 'resnet18': + self.resnet = bitmodule.resnet18(pretrained=True, replace_stride_with_dilation=[False,True,True]) + elif backbone == 'resnet34': + self.resnet = bitmodule.resnet34(pretrained=True, replace_stride_with_dilation=[False,True,True]) + else: + raise NotImplementedError + self.relu = nn.ReLU() + self.upsamplex2 = nn.Upsample(scale_factor=2) + self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') + + self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) + + self.resnet_stages_num = resnet_stages_num + + self.if_upsample_2x = if_upsample_2x + if self.resnet_stages_num == 5: + layers = 512 * expand + elif self.resnet_stages_num == 4: + layers = 256 * expand + elif self.resnet_stages_num == 3: + layers = 128 * expand + else: + raise NotImplementedError + self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1) + + self.output_sigmoid = output_sigmoid + self.sigmoid = nn.Sigmoid() + + def forward(self, x1, x2): + x1 = self.forward_single(x1) + x2 = self.forward_single(x2) + x = torch.abs(x1 - x2) + if not self.if_upsample_2x: + x = self.upsamplex2(x) + x = self.upsamplex4(x) + x = self.classifier(x) + + if self.output_sigmoid: + x = self.sigmoid(x) + return x + + def forward_single(self, x): + # resnet layers + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x = self.resnet.relu(x) + x_64 = self.resnet.maxpool(x) + + x_64_2 = self.resnet.layer1(x_64) # 1/4, in=64, out=64 + x_128 = self.resnet.layer2(x_64_2) # 1/8, in=64, out=128 + + if self.resnet_stages_num > 3: + x_256 = self.resnet.layer3(x_128) # 1/8, in=128, out=256 + + if self.resnet_stages_num == 5: + x_512 = self.resnet.layer4(x_256) # 1/32, in=256, out=512 + elif self.resnet_stages_num > 5: + raise NotImplementedError + + if self.if_upsample_2x: + x = self.upsamplex2(x_512) + else: + x = x_512 + # output layers + x = self.conv_pred(x) + print("in forward singgleee") + return x, x_64, x_64_2, x_128, x_256, x_512 + + +class BASE_UNet_Transformer(ResNet): + """ + Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN + """ + def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, + token_len=4, token_trans=True, + enc_depth=1, dec_depth=1, + dim_head=64, decoder_dim_head=64, + tokenizer=True, if_upsample_2x=True, + pool_mode='max', pool_size=2, + backbone='resnet18', + decoder_softmax=True, with_decoder_pos=None, + with_decoder=True): + super(BASE_UNet_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, + resnet_stages_num=resnet_stages_num, + if_upsample_2x=if_upsample_2x, + ) + + print("using BiT Transformer !!!!") + + self.token_len = token_len + self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, + padding=0, bias=False) + self.tokenizer = tokenizer + if not self.tokenizer: + # if not use tokenzier,then downsample the feature map into a certain size + self.pooling_size = pool_size + self.pool_mode = pool_mode + self.token_len = self.pooling_size * self.pooling_size + + self.token_trans = token_trans + self.with_decoder = with_decoder + dim = 32 + mlp_dim = 2*dim + + self.with_pos = with_pos + if with_pos is 'learned': + self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32)) + decoder_pos_size = 256//4 + self.with_decoder_pos = with_decoder_pos + if self.with_decoder_pos == 'learned': + self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32, + decoder_pos_size, + decoder_pos_size)) + self.enc_depth = enc_depth + self.dec_depth = dec_depth + 
self.dim_head = dim_head + self.decoder_dim_head = decoder_dim_head + self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8, + dim_head=self.dim_head, + mlp_dim=mlp_dim, dropout=0) + self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, + heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, + softmax=decoder_softmax) + + def _forward_semantic_tokens(self, x): + b, c, h, w = x.shape + spatial_attention = self.conv_a(x) + spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() + spatial_attention = torch.softmax(spatial_attention, dim=-1) + x = x.view([b, c, -1]).contiguous() + tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) + + return tokens + + def _forward_reshape_tokens(self, x): + # b,c,h,w = x.shape + if self.pool_mode is 'max': + x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) + elif self.pool_mode is 'ave': + x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) + else: + x = x + tokens = rearrange(x, 'b c h w -> b (h w) c') + return tokens + + def _forward_transformer(self, x): + if self.with_pos: + x += self.pos_embedding + x = self.transformer(x) + return x + + def _forward_transformer_decoder(self, x, m): + b, c, h, w = x.shape + if self.with_decoder_pos == 'fix': + x = x + self.pos_embedding_decoder + elif self.with_decoder_pos == 'learned': + x = x + self.pos_embedding_decoder + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.transformer_decoder(x, m) + x = rearrange(x, 'b (h w) c -> b c h w', h=h) + return x + + def _forward_simple_decoder(self, x, m): + b, c, h, w = x.shape + b, l, c = m.shape + m = m.expand([h,w,b,l,c]) + m = rearrange(m, 'h w b l c -> l b c h w') + m = m.sum(0) + x = x + m + return x + + def forward(self, x): + # forward backbone resnet + x1 = x[:, :3, :, :] + x2 = x[:, 3:, :, :] + x1, x1_64, x1_64_2, x1_128, x1_256, x1_512 = self.forward_single(x1) + x2, x2_64, x2_64_2, x2_128, x2_256, x2_512 = self.forward_single(x2) + print(x1.shape, x1_64.shape, x1_64_2.shape, x1_128.shape, x1_256.shape, x1_512.shape) + + # forward tokenzier + if self.tokenizer: + token1 = self._forward_semantic_tokens(x1) + token2 = self._forward_semantic_tokens(x2) + else: + token1 = self._forward_reshape_tokens(x1) + token2 = self._forward_reshape_tokens(x2) + # forward transformer encoder + if self.token_trans: + self.tokens_ = torch.cat([token1, token2], dim=1) + self.tokens = self._forward_transformer(self.tokens_) + token1, token2 = self.tokens.chunk(2, dim=1) + # forward transformer decoder + if self.with_decoder: + x1 = self._forward_transformer_decoder(x1, token1) + x2 = self._forward_transformer_decoder(x2, token2) + else: + x1 = self._forward_simple_decoder(x1, token1) + x2 = self._forward_simple_decoder(x2, token2) + # feature differencing + x = torch.abs(x1 - x2) + if not self.if_upsample_2x: + x = self.upsamplex2(x) + x = self.upsamplex4(x) + # forward small cnn + x = self.classifier(x) + if self.output_sigmoid: + x = self.sigmoid(x) + return x diff --git a/xBD_code/zoo/senet.py b/xBD_code/zoo/senet.py new file mode 100644 index 0000000..86b4706 --- /dev/null +++ b/xBD_code/zoo/senet.py @@ -0,0 +1,561 @@ +""" +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + +__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', + 
'se_resnext50_32x4d', 'se_resnext101_32x4d'] + +pretrained_settings = { + 'senet154': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet50': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet101': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet152': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnext50_32x4d': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnext101_32x4d': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, +} + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction, concat=False): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, + padding=0) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + +class SCSEModule(nn.Module): + # according to https://arxiv.org/pdf/1808.08127.pdf concat is better + def __init__(self, channels, reduction=16, concat=False): + super(SCSEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, + padding=0) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + self.spatial_se = nn.Sequential(nn.Conv2d(channels, 1, kernel_size=1, + stride=1, padding=0, bias=False), + nn.Sigmoid()) + self.concat = concat + + def forward(self, x): + module_input = x + + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + chn_se = self.sigmoid(x) + chn_se = chn_se * module_input + + spa_se = self.spatial_se(module_input) + spa_se = module_input * spa_se + if self.concat: + return torch.cat([chn_se, spa_se], dim=1) + else: + return chn_se + spa_se + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. 
+ """ + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, + stride=stride, padding=1, groups=groups, + bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SCSEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SCSEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, + stride=stride, padding=1, groups=groups, + bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SCSEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, + stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, + groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, + stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + + +class SCSEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Concurrent Spatial Squeeze-and-Excitation module. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4, final=False): + super(SCSEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, + stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SCSEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, dropout_p=0.2, + inplanes=128, input_3x3=True, downsample_kernel_size=3, + downsample_padding=1, num_classes=1000): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. 
+ - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, + bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, + bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, + bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, + padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. + self.pool = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.avg_pool = nn.AvgPool2d(7, stride=1) + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, num_classes) + self._initialize_weights() + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=downsample_kernel_size, stride=stride, + padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, groups, reduction, stride, + downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + m.weight.data = nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def features(self, x): + x = self.layer0(x) + x = self.pool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.avg_pool(x) + if self.dropout is not None: + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.features(x) + x = self.logits(x) + return x + + +def initialize_pretrained_model(model, num_classes, settings): + assert num_classes == 
settings['num_classes'], \ + 'num_classes should be {}, but is {}'.format( + settings['num_classes'], num_classes) + model.load_state_dict(model_zoo.load_url(settings['url']), strict=False) + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + + +def senet154(num_classes=1000, pretrained='imagenet'): + model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, + dropout_p=0.2, num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['senet154'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + +def scsenet154(num_classes=1000, pretrained='imagenet'): + print("scsenet154") + model = SENet(SCSEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, + dropout_p=0.2, num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['senet154'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def se_resnet50(num_classes=1000, pretrained='imagenet'): + model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnet50'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def se_resnet101(num_classes=1000, pretrained='imagenet'): + model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnet101'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def se_resnet152(num_classes=1000, pretrained='imagenet'): + model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnet152'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'): + model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnext50_32x4d'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def scse_resnext50_32x4d(num_classes=1000, pretrained='imagenet'): + model = SENet(SCSEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnext50_32x4d'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'): + model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, + dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, + 
num_classes=num_classes) + #if pretrained is not None: + # settings = pretrained_settings['se_resnext101_32x4d'][pretrained] + # initialize_pretrained_model(model, num_classes, settings) + return model + + +if __name__ == '__main__': + print(se_resnext50_32x4d()) +
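A minimal usage sketch of the SENet factory functions defined above (illustrative only, not part of the diff; the import path and the 4-class head are assumptions). Note that the pretrained-weight loading inside these factories is commented out, so the pretrained argument currently has no effect and weights would have to be loaded separately.

import torch
from xBD_code.zoo.senet import se_resnext50_32x4d  # import path assumed from this diff's file layout

model = se_resnext50_32x4d(num_classes=4)  # e.g. 4 damage categories; ImageNet weights are NOT loaded here
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)    # 224x224 matches pretrained_settings['input_size']
    logits = model(dummy)                  # features() -> logits(); AvgPool2d(7) expects the 7x7 final feature map
print(logits.shape)                        # torch.Size([1, 4])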