
Updated multi-dataset training warp module to run with this new repo
Works on VITON_VVT_MPV dataset
andrewjong committed Sep 24, 2020
1 parent f8ea092 commit d82fbac
Showing 9 changed files with 91 additions and 50 deletions.
datasets/mpv_dataset.py (17 changes: 12 additions & 5 deletions)

@@ -9,18 +9,19 @@ class MPVDataset(TryonDataset):
     """ CP-VTON dataset with the MPV folder structure. """
 
     @staticmethod
-    def modify_commandline_options(parser: argparse.ArgumentParser, is_train):
-        parser = super(MPVDataset, MPVDataset).modify_commandline_options(
-            parser, is_train
-        )
+    def modify_commandline_options(parser: argparse.ArgumentParser, is_train, shared=False):
+        if not shared:
+            parser = super(MPVDataset, MPVDataset).modify_commandline_options(
+                parser, is_train
+            )
         parser.add_argument("--mpv_dataroot", default="/data_hdd/mpv_competition")
         return parser
 
     def __init__(self, opt):
         super(MPVDataset, self).__init__(opt)
 
     # @overrides(CpVtonDataset)
-    def load_file_paths(self):
+    def load_file_paths(self, i_am_validation=False):
         """ Reads the datalist txt file for CP-VTON"""
         self.root = self.opt.mpv_dataroot
         self.image_names = []
@@ -77,3 +78,9 @@ def get_person_cocopose_path(self, index):
         pose_path = osp.join(self.root, "all_person_clothes_keypoints", image_name)
         pose_path = pose_path.replace(".jpg", "_keypoints.json")
         return pose_path
+
+    def get_person_densepose_path(self, index):
+        return NotImplementedError("THIS IS TODO. For now use cocopose on MPV")
+
+    def get_person_flow_path(self, index):
+        return NotImplementedError("THIS IS TODO. Image datasets don't have flow")
datasets/n_frames_interface.py (29 changes: 22 additions & 7 deletions)

@@ -2,6 +2,7 @@
 import functools
 from abc import ABC, abstractmethod
 from argparse import ArgumentParser
+from typing import Dict
 
 import torch
 from torch.utils.data.dataloader import default_collate
@@ -101,15 +102,29 @@ def wrapper(self, index):
     return wrapper
 
 
-def maybe_combine_frames_and_channels(opt, inputs):
-    """ if n_frames_total is true, combines frames and channels dim for all the tensors"""
+def maybe_combine_frames_and_channels(opt, inputs: Dict, has_batch_dim=True):
+    """
+    if n_frames_total is true, combines frames and channels dim for all the tensors.
+    For tuples, unpacks it from the list that wraps it.
+    Args:
+        opt:
+        inputs:
+        has_batch_dim: whether or not batch dim is already included. If called within
+            a dataset class, this should be False. If called as the output of a
+            dataloader, then should be True.
+    """
     if hasattr(opt, "n_frames_total"):
 
         def maybe_combine(t):
            # Tensor like items
-            if isinstance(t, torch.Tensor) and len(t.shape) == 5:
-                bs, n_frames, c, h, w = t.shape
-                t = t.view(bs, n_frames * c, h, w)
+            if isinstance(t, torch.Tensor):
+                if has_batch_dim and len(t.shape) == 5:
+                    bs, n_frames, c, h, w = t.shape
+                    t = t.view(bs, n_frames * c, h, w)
+                elif not has_batch_dim and len(t.shape) == 4:
+                    n_frames, c, h, w = t.shape
+                    t = t.view(n_frames * c, h, w)
            # Non-tensor like items, such as lists of strings or numbers
             elif isinstance(t, collections.abc.Sequence) and not isinstance(t, str):
                 # unpack
@@ -118,6 +133,6 @@ def maybe_combine(t):
 
             return t
 
-        inputs = {k: maybe_combine(v) for k, v in inputs.items()}
+        new_inputs = {k: maybe_combine(v) for k, v in inputs.items()}
 
-    return inputs
+    return new_inputs
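
Note: the `has_batch_dim` flag exists because the same reshape is now needed in two places: on dataloader output (5-D tensors) and on raw dataset items (4-D tensors). A minimal sketch of the two cases, with made-up shapes that are not taken from the repo:

```
import torch

# A hypothetical clip: 3 frames of 4-channel 256x192 tensors.
from_dataloader = torch.zeros(8, 3, 4, 256, 192)  # has_batch_dim=True case: (bs, n_frames, c, h, w)
from_dataset = torch.zeros(3, 4, 256, 192)        # has_batch_dim=False case: (n_frames, c, h, w)

# The reshape maybe_combine applies to each tensor:
print(from_dataloader.view(8, 3 * 4, 256, 192).shape)  # torch.Size([8, 12, 256, 192])
print(from_dataset.view(3 * 4, 256, 192).shape)        # torch.Size([12, 256, 192])
```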
datasets/viton_dataset.py (17 changes: 12 additions & 5 deletions)

@@ -8,10 +8,11 @@ class VitonDataset(TryonDataset):
     """ CP-VTON dataset with the original Viton folder structure """
 
     @staticmethod
-    def modify_commandline_options(parser: argparse.ArgumentParser, is_train):
-        parser = super(VitonDataset, VitonDataset).modify_commandline_options(
-            parser, is_train
-        )
+    def modify_commandline_options(parser: argparse.ArgumentParser, is_train, shared=False):
+        if not shared:
+            parser = super(VitonDataset, VitonDataset).modify_commandline_options(
+                parser, is_train
+            )
         parser.add_argument("--viton_dataroot", default="data")
         parser.add_argument("--data_list", default="train_pairs.txt")
         return parser
@@ -22,7 +23,7 @@ def __init__(self, opt):
         self.data_path = osp.join(opt.viton_dataroot, opt.datamode)
 
     # @overrides
-    def load_file_paths(self):
+    def load_file_paths(self, i_am_validation=False):
         """
         Reads the datalist txt file for CP-VTON
         sets self.image_names and self.cloth_names. they should correspond 1-to-1
@@ -88,3 +89,9 @@ def get_person_cocopose_path(self, index):
         _pose_name = im_name.replace(".jpg", "_keypoints.json")
         pose_path = osp.join(self.data_path, "pose", _pose_name)
         return pose_path
+
+    def get_person_flow_path(self, index):
+        return NotImplementedError("THIS IS TODO. Image datasets don't have flow")
+
+    def get_person_densepose_path(self, index):
+        return NotImplementedError("THIS IS TODO. For now use cocopose on VITON")
datasets/viton_vvt_mpv_dataset.py (42 changes: 17 additions & 25 deletions)

@@ -5,9 +5,10 @@
 
 from datasets import BaseDataset
 from datasets.mpv_dataset import MPVDataset
+from datasets.n_frames_interface import maybe_combine_frames_and_channels
+from datasets.tryon_dataset import TryonDatasetType
 from datasets.viton_dataset import VitonDataset
 from datasets.vvt_dataset import VVTDataset
-from options.train_options import TrainOptions
 
 
 class VitonVvtMpvDataset(BaseDataset):
@@ -16,9 +17,9 @@ class VitonVvtMpvDataset(BaseDataset):
 
     @staticmethod
     def modify_commandline_options(parser: ArgumentParser, is_train):
-        parser = VitonDataset.modify_commandline_options(parser, is_train)
         parser = VVTDataset.modify_commandline_options(parser, is_train)
-        parser = MPVDataset.modify_commandline_options(parser, is_train)
+        parser = VitonDataset.modify_commandline_options(parser, is_train, shared=True)
+        parser = MPVDataset.modify_commandline_options(parser, is_train, shared=True)
         return parser
 
     def name(self):
@@ -35,39 +36,30 @@ def __init__(self, opt):
 
         self.transforms = transforms.Compose([])
 
+    @classmethod
+    def make_validation_dataset(self, opt) -> TryonDatasetType:
+        val = VVTDataset(opt, i_am_validation=True)
+        return val
+
     def __getitem__(self, index):
         if index < len(self.viton_dataset):
             item = self.viton_dataset[index]
             return item
 
         index -= len(self.viton_dataset)
 
         if index < len(self.vvt_dataset):
             item = self.vvt_dataset[index]
+            if self.opt.model == "warp":
+                assert self.opt.n_frames_total == 1, (
+                    f"{self.opt.n_frames_total=}; "
+                    f"warp model shouldn't be using n_frames_total > 1"
+                )
+            item = maybe_combine_frames_and_channels(self.opt, item, has_batch_dim=False)
             return item
 
         index -= len(self.vvt_dataset)
 
         item = self.mpv_dataset[index]
         return item
 
     def __len__(self):
         return len(self.viton_dataset) + len(self.vvt_dataset) + len(self.mpv_dataset)
-
-
-if __name__ == "__main__":
-    print("Check the dataset for geometric matching module!")
-
-    opt = TrainOptions().parse()
-
-    dataset = VitonVvtMpvDataset(opt)
-    data_loader = CPDataLoader(opt, dataset)
-
-    print(
-        "Size of the dataset: %05d, dataloader: %04d"
-        % (len(dataset), len(data_loader.data_loader))
-    )
-    first_item = dataset.__getitem__(0)
-    first_batch = data_loader.next_batch()
-
-    from IPython import embed
-
-    embed()
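
Note: the `shared=True` calls above avoid registering the common `TryonDataset` options more than once; argparse raises an `ArgumentError` when the same option string is added twice, so only `VVTDataset` pulls in the shared parent options here. A standalone sketch of the pattern (option names below are illustrative, not taken from the repo):

```
import argparse

def add_shared_options(parser):
    # Stand-in for options a common parent (like TryonDataset) would register.
    parser.add_argument("--fine_width", type=int, default=192)
    return parser

def add_viton_options(parser, shared=False):
    if not shared:
        parser = add_shared_options(parser)
    parser.add_argument("--viton_dataroot", default="data")
    return parser

def add_mpv_options(parser, shared=False):
    if not shared:
        parser = add_shared_options(parser)
    parser.add_argument("--mpv_dataroot", default="/data_hdd/mpv_competition")
    return parser

parser = argparse.ArgumentParser()
parser = add_shared_options(parser)              # shared options registered exactly once
parser = add_viton_options(parser, shared=True)  # skip the duplicate registration
parser = add_mpv_options(parser, shared=True)
# Calling add_viton_options(parser) without shared=True would raise
# argparse.ArgumentError: conflicting option string: --fine_width
print(parser.parse_args([]))
```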
datasets/vvt_dataset.py (5 changes: 3 additions & 2 deletions)

@@ -15,8 +15,9 @@ class VVTDataset(TryonDataset, NFramesInterface):
     """ CP-VTON dataset with FW-GAN's VVT folder structure. """
 
     @staticmethod
-    def modify_commandline_options(parser: argparse.ArgumentParser, is_train):
-        parser = TryonDataset.modify_commandline_options(parser, is_train)
+    def modify_commandline_options(parser: argparse.ArgumentParser, is_train, shared=False):
+        if not shared:
+            parser = TryonDataset.modify_commandline_options(parser, is_train)
         parser = NFramesInterface.modify_commandline_options(parser, is_train)
         parser.add_argument("--vvt_dataroot", default="/data_hdd/fw_gan_vvt")
         parser.add_argument(
docs/3_train.md (13 changes: 13 additions & 0 deletions)

@@ -115,5 +115,18 @@ COMING SOON
 
 
 COMING SOON
+```
+python train.py \
+--name train_warp-cloth_viton-vvt-mpv \
+--model warp \
+--dataset viton_vvt_mpv \
+--viton_dataroot /data_hdd/cp-vton/viton_processed \
+--vvt_dataroot /data_hdd/fw_gan_vvt \
+--mpv_dataroot /data_hdd/mpv_competition \
+--workers 4 \
+--gpu_ids 0 \
+--batch_size 16 \
+--accumulated 4
+```
 
 </details>
models/warp_model.py (11 changes: 7 additions & 4 deletions)

@@ -11,7 +11,6 @@
 from models.base_model import BaseModel
 from util import get_and_cat_inputs
 from models.networks.cpvton.warp import (
-
     FeatureExtraction,
     FeatureL2Norm,
     FeatureCorrelation,
@@ -32,7 +31,9 @@ def modify_commandline_options(cls, parser: ArgumentParser, is_train):
         parser = ArgumentParser(parents=[parser], add_help=False)
         parser = super(WarpModel, cls).modify_commandline_options(parser, is_train)
         parser.add_argument("--grid_size", type=int, default=5)
-        parser.set_defaults(person_inputs=("agnostic", "densepose"))
+        parser.set_defaults(person_inputs=("agnostic", "cocopose"))
+        # TODO: We don't have densepose created for VITON and MPV yet
+        # parser.set_defaults(person_inputs=("agnostic", "densepose"))
         return parser
 
     def __init__(self, hparams):
@@ -69,7 +70,7 @@ def forward(self, inputA, inputB):
         grid = self.gridGen(theta)
         return grid, theta
 
-    def training_step(self, batch, _):
+    def training_step(self, batch, idx, val=False):
         batch = maybe_combine_frames_and_channels(self.hparams, batch)
         # unpack
         c = batch["cloth"]
@@ -86,7 +87,7 @@ def training_step(self, batch, idx, val=False):
         loss = F.l1_loss(self.warped_cloth, im_c)
 
         # Logging
-        if self.global_step % self.hparams.display_count == 0:
+        if not val and self.global_step % self.hparams.display_count == 0:
             self.visualize(batch)
 
         tensorboard_scalars = {"epoch": self.current_epoch, "loss": loss}
@@ -133,6 +134,8 @@ def test_step(self, batch, batch_idx):
         return result
 
     def visualize(self, b, tag="train"):
+        if tag == "validation":
+            b = maybe_combine_frames_and_channels(self.hparams, b)
         person_visuals = self.fetch_person_visuals(b)
 
         visuals = [
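
Note: the new `idx` argument and `val=False` flag on `training_step` suggest the same computation is reused during validation with the periodic `visualize()` call suppressed. A hypothetical caller, not part of this commit, might look like:

```
# Hypothetical sketch (assumed Lightning-style hook, not from the repo):
def validation_step(self, batch, batch_idx):
    # Reuse the training computation but skip the display_count visualization.
    return self.training_step(batch, batch_idx, val=True)
```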
options/train_options.py (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@ def initialize(self, parser: argparse.ArgumentParser):
             "--val_frequency",
             dest="val_check_interval",
             type=str,
-            default="25000",  # parsed later into int or float based on "."
+            default="0.125",  # parsed later into int or float based on "."
             help="If float, validate (and checkpoint) after this many epochs. "
             "If int, validate after this many batches. If 0 or 0.0, validate "
             "every step."
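
Note: per the inline comment, the string default is parsed later into an int (batches) or a float (fraction of an epoch) depending on whether it contains a dot. A minimal sketch of that convention (the helper name is made up, not from the repo):

```
def parse_val_check_interval(value: str):
    # "0.125" -> 0.125 (validate every eighth of an epoch)
    # "25000" -> 25000 (validate every 25000 batches)
    return float(value) if "." in value else int(value)

assert parse_val_check_interval("0.125") == 0.125
assert parse_val_check_interval("25000") == 25000
```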
visualization.py (5 changes: 4 additions & 1 deletion)

@@ -5,13 +5,16 @@
 
 
 def tensor_for_board(img_tensor):
+    assert (
+        img_tensor.ndim == 4
+    ), f"something's not right, i'm not a standard img_tensor. {img_tensor.shape=}"
     # map into [0,1]
     tensor = (img_tensor.clone() + 1) * 0.5
     try:
         tensor.cpu().clamp(0, 1)
     except:
         tensor.float().cpu().clamp(0, 1)
-    if tensor.size(1) == 1:  # masks, make it RGB
+    if tensor.shape[1] == 1:  # masks, make it RGB
         tensor = tensor.repeat(1, 3, 1, 1)
 
     return tensor
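
Note: a quick illustration of the new 4-D assertion and the existing mask handling (input values here are arbitrary):

```
import torch
from visualization import tensor_for_board

mask = torch.rand(8, 1, 256, 192) * 2 - 1  # 4-D single-channel input in [-1, 1]
print(tensor_for_board(mask).shape)        # torch.Size([8, 3, 256, 192]) after the RGB repeat

# A 3-D input now fails fast instead of breaking downstream:
# tensor_for_board(torch.rand(1, 256, 192))  # -> AssertionError
```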
