From be683689db833c312a27ef35e8ca517090109962 Mon Sep 17 00:00:00 2001
From: Kartikaeya <kartikaeya@gmail.com>
Date: Sun, 22 Oct 2023 17:07:54 -0400
Subject: [PATCH] Add support for training with a known background (bgr) input
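
The model now takes a 6-channel input: the source frame concatenated with a
known background plate along the channel dimension,

    fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)  # (B, T, 6, H, W)

To support this:
- VideoMatteDataset samples only video backgrounds; the image-background
  branch and the segmentation datasets/pass are disabled.
- The decoder's input skip connections and output block are widened from 3
  to 6 channels, and pretrained MobileNetV3 weights are loaded by matching
  key and shape, so layers resized for the new input keep their
  initialization.
- ImageSequenceReader reads paired frames from fgr/ and bgr/ sub-folders and
  returns them stacked as one tensor.
- train_mat composites the foreground over an augmented copy of the
  background (shadow, noise, color jitter, affine) before feeding both to
  the model.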

---
 dataset/videomatte.py     |  24 ++--
 inference.py              |   4 +
 inference_utils.py        |  19 ++-
 model/decoder.py          |  10 +-
 model/mobilenetv3.py      |  31 ++++-
 model/model.py            |   5 +-
 requirements_training.txt |  10 +-
 train.py                  | 271 ++++++++++++++++++++++----------------
 train_config.py           |  42 +++---
 9 files changed, 247 insertions(+), 169 deletions(-)

diff --git a/dataset/videomatte.py b/dataset/videomatte.py
index 555911b..2d86849 100644
--- a/dataset/videomatte.py
+++ b/dataset/videomatte.py
@@ -9,14 +9,14 @@
 class VideoMatteDataset(Dataset):
     def __init__(self,
                  videomatte_dir,
-                 background_image_dir,
+                #  background_image_dir,
                  background_video_dir,
                  size,
                  seq_length,
                  seq_sampler,
                  transform=None):
-        self.background_image_dir = background_image_dir
-        self.background_image_files = os.listdir(background_image_dir)
+        # self.background_image_dir = background_image_dir
+        # self.background_image_files = os.listdir(background_image_dir)
         self.background_video_dir = background_video_dir
         self.background_video_clips = sorted(os.listdir(background_video_dir))
         self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
@@ -38,10 +38,10 @@ def __len__(self):
         return len(self.videomatte_idx)
     
     def __getitem__(self, idx):
-        if random.random() < 0.5:
-            bgrs = self._get_random_image_background()
-        else:
-            bgrs = self._get_random_video_background()
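+        # Image backgrounds are disabled for bgr training; always sample a video background.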
+        # if random.random() < 0.5:
+        #     bgrs = self._get_random_image_background()
+        # else:
+        bgrs = self._get_random_video_background()
         
         fgrs, phas = self._get_videomatte(idx)
         
@@ -50,11 +50,11 @@ def __getitem__(self, idx):
         
         return fgrs, phas, bgrs
     
-    def _get_random_image_background(self):
-        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
-            bgr = self._downsample_if_needed(bgr.convert('RGB'))
-        bgrs = [bgr] * self.seq_length
-        return bgrs
+    # def _get_random_image_background(self):
+    #     with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
+    #         bgr = self._downsample_if_needed(bgr.convert('RGB'))
+    #     bgrs = [bgr] * self.seq_length
+    #     return bgrs
     
     def _get_random_video_background(self):
         clip_idx = random.choice(range(len(self.background_video_clips)))
diff --git a/inference.py b/inference.py
index a116754..3a1d1b3 100644
--- a/inference.py
+++ b/inference.py
@@ -120,6 +120,10 @@ def convert_video(model,
             rec = [None] * 4
             for src in reader:
 
+                # Crop odd width/height down to even so the strided encoder and
+                # the decoder's upsampling stages produce matching sizes
+                if src.shape[-1] % 2 == 1:
+                    src = src[:, :, :, :-1]
+                if src.shape[-2] % 2 == 1:
+                    src = src[:, :, :-1, :]
                 if downsample_ratio is None:
                     downsample_ratio = auto_downsample_ratio(*src.shape[2:])
 
diff --git a/inference_utils.py b/inference_utils.py
index d651dc0..c6b4111 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset
 from torchvision.transforms.functional import to_pil_image
 from PIL import Image
-
+import torch
 
 class VideoReader(Dataset):
     def __init__(self, path, transform=None):
@@ -55,18 +55,23 @@ def close(self):
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
-        self.files = sorted(os.listdir(path))
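+        # Expects two parallel sub-folders under `path`: fgr/ (source frames)
+        # and bgr/ (matching background frames) with aligned, sorted filenames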
+        self.files_fgr = sorted(os.listdir(os.path.join(path, 'fgr')))
+        self.files_bgr = sorted(os.listdir(os.path.join(path, 'bgr')))
         self.transform = transform
         
     def __len__(self):
-        return len(self.files)
+        return len(self.files_fgr)
     
     def __getitem__(self, idx):
-        with Image.open(os.path.join(self.path, self.files[idx])) as img:
-            img.load()
+        with Image.open(os.path.join(self.path + "fgr/", self.files_fgr[idx])) as fgr_img:
+            fgr_img.load()
+
+        with Image.open(os.path.join(self.path + "bgr/", self.files_bgr[idx])) as bgr_img:
+            bgr_img.load()
+        
         if self.transform is not None:
-            return self.transform(img)
-        return img
+            return torch.cat([self.transform(fgr_img), self.transform(bgr_img)], dim = 0)
+        return fgr_img
 
 
 class ImageSequenceWriter:
diff --git a/model/decoder.py b/model/decoder.py
index 7307435..7569429 100644
--- a/model/decoder.py
+++ b/model/decoder.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
-from torch.nn import functional as F
+# from torch.nn import functional as F
 from typing import Tuple, Optional
 
 class RecurrentDecoder(nn.Module):
@@ -9,10 +9,10 @@ def __init__(self, feature_channels, decoder_channels):
         super().__init__()
         self.avgpool = AvgPool()
         self.decode4 = BottleneckBlock(feature_channels[3])
-        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
-        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
-        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
-        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
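+        # Input skip connections now carry 6 channels (source + background)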
+        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 6, decoder_channels[0])
+        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 6, decoder_channels[1])
+        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 6, decoder_channels[2])
+        self.decode0 = OutputBlock(decoder_channels[2], 6, decoder_channels[3])
 
     def forward(self,
                 s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py
index 712a298..271a93e 100644
--- a/model/mobilenetv3.py
+++ b/model/mobilenetv3.py
@@ -3,6 +3,21 @@
 from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
 from torchvision.transforms.functional import normalize
 
+def load_matched_state_dict(model, state_dict, print_stats=True):
+    """
+    Loads only the weights whose key and shape match the model; all other weights are ignored.
+    """
+    num_matched, num_total = 0, 0
+    curr_state_dict = model.state_dict()
+    for key in curr_state_dict.keys():
+        num_total += 1
+        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
+            curr_state_dict[key] = state_dict[key]
+            num_matched += 1
+    model.load_state_dict(curr_state_dict)
+    if print_stats:
+        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
+
 class MobileNetV3LargeEncoder(MobileNetV3):
     def __init__(self, pretrained: bool = False):
         super().__init__(
@@ -27,14 +42,24 @@ def __init__(self, pretrained: bool = False):
         )
         
         if pretrained:
-            self.load_state_dict(torch.hub.load_state_dict_from_url(
-                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
+            pretrained_state_dict = torch.hub.load_state_dict_from_url(
+                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
+            # Load only the weights whose key and shape still match the ImageNet
+            # checkpoint; layers resized for the 6-channel input keep their
+            # random initialization
+            load_matched_state_dict(self, pretrained_state_dict)
 
         del self.avgpool
         del self.classifier
         
     def forward_single_frame(self, x):
-        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        # Normalize the source (first 3) and background (last 3) channels
+        # separately with the same ImageNet statistics
+        x = torch.cat([
+            normalize(x[:, :3], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            normalize(x[:, 3:], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ], dim=-3)
         
         x = self.features[0](x)
         x = self.features[1](x)
diff --git a/model/model.py b/model/model.py
index 71fc684..bd47f0c 100644
--- a/model/model.py
+++ b/model/model.py
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torchsummary import summary
 from torch.nn import functional as F
 from typing import Optional, List
 
@@ -58,8 +59,8 @@ def forward(self,
         if not segmentation_pass:
             fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
             if downsample_ratio != 1:
-                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
-            fgr = fgr_residual + src
+                fgr_residual, pha = self.refiner(src[:, :, :3], src_sm[:, :, :3], fgr_residual, pha, hid)
+            # The residual applies to the source frame only, not the bgr channels
+            fgr = fgr_residual + src[:, :, :3]
             fgr = fgr.clamp(0., 1.)
             pha = pha.clamp(0., 1.)
             return [fgr, pha, *rec]
diff --git a/requirements_training.txt b/requirements_training.txt
index 70fd4b1..0ac0666 100644
--- a/requirements_training.txt
+++ b/requirements_training.txt
@@ -1,5 +1,7 @@
 easing_functions==1.0.4
-tensorboard==2.5.0
-torch==1.9.0
-torchvision==0.10.0
-tqdm==4.61.1
\ No newline at end of file
+tensorboard
+torch
+torchvision
+tqdm==4.61.1
+opencv-python==4.6.0.66
+torchsummary
\ No newline at end of file
diff --git a/train.py b/train.py
index 462bd1f..a542add 100644
--- a/train.py
+++ b/train.py
@@ -121,7 +121,8 @@
 from model import MattingNetwork
 from train_config import DATA_PATHS
 from train_loss import matting_loss, segmentation_loss
-
+import kornia
+from torchvision import transforms as T
 
 class Trainer:
     def __init__(self, rank, world_size):
@@ -189,7 +190,7 @@ def init_datasets(self):
         if self.args.dataset == 'videomatte':
             self.dataset_lr_train = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
+                # background_image_dir=DATA_PATHS['background_images']['train'],
                 background_video_dir=DATA_PATHS['background_videos']['train'],
                 size=self.args.resolution_lr,
                 seq_length=self.args.seq_length_lr,
@@ -198,7 +199,7 @@ def init_datasets(self):
             if self.args.train_hr:
                 self.dataset_hr_train = VideoMatteDataset(
                     videomatte_dir=DATA_PATHS['videomatte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
+                    # background_image_dir=DATA_PATHS['background_images']['train'],
                     background_video_dir=DATA_PATHS['background_videos']['train'],
                     size=self.args.resolution_hr,
                     seq_length=self.args.seq_length_hr,
@@ -206,38 +207,38 @@ def init_datasets(self):
                     transform=VideoMatteTrainAugmentation(size_hr))
             self.dataset_valid = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
+                # background_image_dir=DATA_PATHS['background_images']['valid'],
                 background_video_dir=DATA_PATHS['background_videos']['valid'],
                 size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                 seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                 seq_sampler=ValidFrameSampler(),
                 transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
-        else:
-            self.dataset_lr_train = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
-                background_video_dir=DATA_PATHS['background_videos']['train'],
-                size=self.args.resolution_lr,
-                seq_length=self.args.seq_length_lr,
-                seq_sampler=TrainFrameSampler(),
-                transform=ImageMatteAugmentation(size_lr))
-            if self.args.train_hr:
-                self.dataset_hr_train = ImageMatteDataset(
-                    imagematte_dir=DATA_PATHS['imagematte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
-                    background_video_dir=DATA_PATHS['background_videos']['train'],
-                    size=self.args.resolution_hr,
-                    seq_length=self.args.seq_length_hr,
-                    seq_sampler=TrainFrameSampler(),
-                    transform=ImageMatteAugmentation(size_hr))
-            self.dataset_valid = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
-                background_video_dir=DATA_PATHS['background_videos']['valid'],
-                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
-                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
-                seq_sampler=ValidFrameSampler(),
-                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
+        # else:
+        #     self.dataset_lr_train = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #         background_image_dir=DATA_PATHS['background_images']['train'],
+        #         background_video_dir=DATA_PATHS['background_videos']['train'],
+        #         size=self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_lr,
+        #         seq_sampler=TrainFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_lr))
+        #     if self.args.train_hr:
+        #         self.dataset_hr_train = ImageMatteDataset(
+        #             imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #             background_image_dir=DATA_PATHS['background_images']['train'],
+        #             background_video_dir=DATA_PATHS['background_videos']['train'],
+        #             size=self.args.resolution_hr,
+        #             seq_length=self.args.seq_length_hr,
+        #             seq_sampler=TrainFrameSampler(),
+        #             transform=ImageMatteAugmentation(size_hr))
+        #     self.dataset_valid = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['valid'],
+        #         background_image_dir=DATA_PATHS['background_images']['valid'],
+        #         background_video_dir=DATA_PATHS['background_videos']['valid'],
+        #         size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
+        #         seq_sampler=ValidFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
             
         # Matting dataloaders:
         self.datasampler_lr_train = DistributedSampler(
@@ -270,49 +271,49 @@ def init_datasets(self):
             pin_memory=True)
         
         # Segementation datasets
-        self.log('Initializing image segmentation datasets')
-        self.dataset_seg_image = ConcatDataset([
-            CocoPanopticDataset(
-                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
-                anndir=DATA_PATHS['coco_panoptic']['anndir'],
-                annfile=DATA_PATHS['coco_panoptic']['annfile'],
-                transform=CocoPanopticTrainAugmentation(size_lr)),
-            SuperviselyPersonDataset(
-                imgdir=DATA_PATHS['spd']['imgdir'],
-                segdir=DATA_PATHS['spd']['segdir'],
-                transform=CocoPanopticTrainAugmentation(size_lr))
-        ])
-        self.datasampler_seg_image = DistributedSampler(
-            dataset=self.dataset_seg_image,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_image = DataLoader(
-            dataset=self.dataset_seg_image,
-            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_image,
-            pin_memory=True)
+        # self.log('Initializing image segmentation datasets')
+        # self.dataset_seg_image = ConcatDataset([
+        #     CocoPanopticDataset(
+        #         imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
+        #         anndir=DATA_PATHS['coco_panoptic']['anndir'],
+        #         annfile=DATA_PATHS['coco_panoptic']['annfile'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr)),
+        #     SuperviselyPersonDataset(
+        #         imgdir=DATA_PATHS['spd']['imgdir'],
+        #         segdir=DATA_PATHS['spd']['segdir'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr))
+        # ])
+        # self.datasampler_seg_image = DistributedSampler(
+        #     dataset=self.dataset_seg_image,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_image = DataLoader(
+        #     dataset=self.dataset_seg_image,
+        #     batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_image,
+        #     pin_memory=True)
         
-        self.log('Initializing video segmentation datasets')
-        self.dataset_seg_video = YouTubeVISDataset(
-            videodir=DATA_PATHS['youtubevis']['videodir'],
-            annfile=DATA_PATHS['youtubevis']['annfile'],
-            size=self.args.resolution_lr,
-            seq_length=self.args.seq_length_lr,
-            seq_sampler=TrainFrameSampler(speed=[1]),
-            transform=YouTubeVISAugmentation(size_lr))
-        self.datasampler_seg_video = DistributedSampler(
-            dataset=self.dataset_seg_video,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_video = DataLoader(
-            dataset=self.dataset_seg_video,
-            batch_size=self.args.batch_size_per_gpu,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_video,
-            pin_memory=True)
+        # self.log('Initializing video segmentation datasets')
+        # self.dataset_seg_video = YouTubeVISDataset(
+        #     videodir=DATA_PATHS['youtubevis']['videodir'],
+        #     annfile=DATA_PATHS['youtubevis']['annfile'],
+        #     size=self.args.resolution_lr,
+        #     seq_length=self.args.seq_length_lr,
+        #     seq_sampler=TrainFrameSampler(speed=[1]),
+        #     transform=YouTubeVISAugmentation(size_lr))
+        # self.datasampler_seg_video = DistributedSampler(
+        #     dataset=self.dataset_seg_video,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_video = DataLoader(
+        #     dataset=self.dataset_seg_video,
+        #     batch_size=self.args.batch_size_per_gpu,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_video,
+        #     pin_memory=True)
         
     def init_model(self):
         self.log('Initializing model')
@@ -359,12 +360,12 @@ def train(self):
                     self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')
                 
                 # Segmentation pass
-                if self.step % 2 == 0:
-                    true_img, true_seg = self.load_next_seg_video_sample()
-                    self.train_seg(true_img, true_seg, log_label='seg_video')
-                else:
-                    true_img, true_seg = self.load_next_seg_image_sample()
-                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
+                # if self.step % 2 == 0:
+                #     true_img, true_seg = self.load_next_seg_video_sample()
+                #     self.train_seg(true_img, true_seg, log_label='seg_video')
+                # else:
+                #     true_img, true_seg = self.load_next_seg_image_sample()
+                #     self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
                     
                 if self.step % self.args.checkpoint_save_interval == 0:
                     self.save()
@@ -376,10 +377,47 @@ def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
         true_pha = true_pha.to(self.rank, non_blocking=True)
         true_bgr = true_bgr.to(self.rank, non_blocking=True)
         true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
-        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
         
+        true_src = true_bgr.clone()
+        
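+        # The augmentations below perturb the background copy fed to the network
+        # so it never matches the composite pixel-perfectly, mimicking an
+        # imperfect real-world background capture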
+        # Augment bgr with shadow
+        aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+        if aug_shadow_idx.any():
+            aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()).flatten(start_dim=0, end_dim=1)
+            aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+            aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+            expected_shape = list(true_src[aug_shadow_idx].shape)
+            expected_shape[2] = -1  # let reshape infer the single shadow channel
+            true_src[aug_shadow_idx] = true_src[aug_shadow_idx] \
+                .sub_(aug_shadow.reshape(expected_shape)).clamp_(0, 1)
+            del aug_shadow
+        del aug_shadow_idx
+        
+        # Composite foreground onto source
+        true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+        # Augment with noise
+        aug_noise_idx = torch.rand(len(true_src)) < 0.4
+        if aug_noise_idx.any():
+            true_src[aug_noise_idx] = true_src[aug_noise_idx] \
+                .add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx] \
+                .add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+        del aug_noise_idx
+        
+        # Augment background with jitter
+        aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+        if aug_jitter_idx.any():
+            true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(
+                true_bgr[aug_jitter_idx].flatten(start_dim=0, end_dim=1)).reshape(true_bgr[aug_jitter_idx].shape)
+        del aug_jitter_idx
+        
+        # Augment background with affine
+        aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+        if aug_affine_idx.any():
+            true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(
+                true_bgr[aug_affine_idx].flatten(start_dim=0, end_dim=1)).reshape(true_bgr[aug_affine_idx].shape)
+        del aug_affine_idx
+
+        fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)  # (B, T, 6, H, W)
+
         with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
+            pred_fgr, pred_pha = self.model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
             loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
 
         self.scaler.scale(loss['total']).backward()
@@ -397,29 +435,30 @@ def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
             self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
             self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
             self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
-            
-    def train_seg(self, true_img, true_seg, log_label):
-        true_img = true_img.to(self.rank, non_blocking=True)
-        true_seg = true_seg.to(self.rank, non_blocking=True)
+    
+    # Not used for bgr training: the segmentation pass is disabled in train()
+    # def train_seg(self, true_img, true_seg, log_label):
+    #     true_img = true_img.to(self.rank, non_blocking=True)
+    #     true_seg = true_seg.to(self.rank, non_blocking=True)
         
-        true_img, true_seg = self.random_crop(true_img, true_seg)
+    #     true_img, true_seg = self.random_crop(true_img, true_seg)
         
-        with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
-            loss = segmentation_loss(pred_seg, true_seg)
+    #     with autocast(enabled=not self.args.disable_mixed_precision):
+    #         pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
+    #         loss = segmentation_loss(pred_seg, true_seg)
         
-        self.scaler.scale(loss).backward()
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-        self.optimizer.zero_grad()
+    #     self.scaler.scale(loss).backward()
+    #     self.scaler.step(self.optimizer)
+    #     self.scaler.update()
+    #     self.optimizer.zero_grad()
         
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
-            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
+    #         self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
         
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
-            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
+    #         self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
     
     def load_next_mat_hr_sample(self):
         try:
@@ -430,23 +469,23 @@ def load_next_mat_hr_sample(self):
             sample = next(self.dataiterator_mat_hr)
         return sample
     
-    def load_next_seg_video_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_video)
-        except:
-            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
-            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
-            sample = next(self.dataiterator_seg_video)
-        return sample
+    # def load_next_seg_video_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_video)
+    #     except:
+    #         self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
+    #         self.dataiterator_seg_video = iter(self.dataloader_seg_video)
+    #         sample = next(self.dataiterator_seg_video)
+    #     return sample
     
-    def load_next_seg_image_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_image)
-        except:
-            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
-            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
-            sample = next(self.dataiterator_seg_image)
-        return sample
+    # def load_next_seg_image_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_image)
+    #     except:
+    #         self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
+    #         self.dataiterator_seg_image = iter(self.dataloader_seg_image)
+    #         sample = next(self.dataiterator_seg_image)
+    #     return sample
     
     def validate(self):
         if self.rank == 0:
@@ -461,7 +500,9 @@ def validate(self):
                         true_bgr = true_bgr.to(self.rank, non_blocking=True)
                         true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                         batch_size = true_src.size(0)
-                        pred_fgr, pred_pha = self.model(true_src)[:2]
+
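+                        # Pack source + background into the 6-channel input, as in train_mat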
+                        fg_bg_input = torch.cat((true_src, true_bgr), dim=-3)
+                        pred_fgr, pred_pha = self.model(fg_bg_input)[:2]
                         total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                         total_count += batch_size
             avg_loss = total_loss / total_count
diff --git a/train_config.py b/train_config.py
index 0792696..c122a55 100644
--- a/train_config.py
+++ b/train_config.py
@@ -37,32 +37,32 @@
         'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
         'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
     },
-    'imagematte': {
-        'train': '../matting-data/ImageMatte/train',
-        'valid': '../matting-data/ImageMatte/valid',
-    },
-    'background_images': {
-        'train': '../matting-data/Backgrounds/train',
-        'valid': '../matting-data/Backgrounds/valid',
-    },
+    # 'imagematte': {
+    #     'train': '../matting-data/ImageMatte/train',
+    #     'valid': '../matting-data/ImageMatte/valid',
+    # },
+    # 'background_images': {
+    #     'train': '../matting-data/Backgrounds/train',
+    #     'valid': '../matting-data/Backgrounds/valid',
+    # },
     'background_videos': {
         'train': '../matting-data/BackgroundVideos/train',
         'valid': '../matting-data/BackgroundVideos/valid',
     },
     
     
-    'coco_panoptic': {
-        'imgdir': '../matting-data/coco/train2017/',
-        'anndir': '../matting-data/coco/panoptic_train2017/',
-        'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
-    },
-    'spd': {
-        'imgdir': '../matting-data/SuperviselyPersonDataset/img',
-        'segdir': '../matting-data/SuperviselyPersonDataset/seg',
-    },
-    'youtubevis': {
-        'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
-        'annfile': '../matting-data/YouTubeVIS/train/instances.json',
-    }
+    # 'coco_panoptic': {
+    #     'imgdir': '../matting-data/coco/train2017/',
+    #     'anndir': '../matting-data/coco/panoptic_train2017/',
+    #     'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
+    # },
+    # 'spd': {
+    #     'imgdir': '../matting-data/SuperviselyPersonDataset/img',
+    #     'segdir': '../matting-data/SuperviselyPersonDataset/seg',
+    # },
+    # 'youtubevis': {
+    #     'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
+    #     'annfile': '../matting-data/YouTubeVIS/train/instances.json',
+    # }
     
 }