Our approach comprises two models: the core SAMS-GAN and the auxiliary WarpModule. We also compare against a baseline UNet-Mask model (based on the TOM model from CP-VTON).
The WarpModule pre-warps the garment image to the shape of the user; it can be treated as a black box. You can simply download the pre-warped cloths here (COMING SOON) and place them in `${PROJECT_ROOT}/warp-cloth`. To warp cloths on your own data, we provide pretrained weights for the WarpModule here (COMING SOON).
View in TensorBoard
All training progress can be viewed in TensorBoard:
tensorboard --logdir experiments/
We can port-forward a TensorBoard connection from a remote server like this:
ssh -N -L localhost:6006:localhost:6006 [email protected]
Common Train Options
Experiment Setup
- `--name`: experiment name; checkpoints and logs are saved to `experiments/{name}`
- `--gpu_ids`: which GPU id(s) to train on
- `--workers`: number of data-loading workers
- `--keep_epochs`: let the optimizer handle the learning rate for this many epochs
- `--decay_epochs`: linearly decay the learning rate for this many epochs, after `keep_epochs` completes (see the schedule sketch after this list)
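The keep-then-decay schedule described by `--keep_epochs` / `--decay_epochs` can be pictured with a small PyTorch sketch. This illustrates the shape of the schedule only; it is not the repo's actual implementation, and the model and optimizer below are stand-ins:

```python
import torch

keep_epochs, decay_epochs = 100, 100  # stand-ins for --keep_epochs / --decay_epochs

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

def lr_lambda(epoch):
    # Hold the full learning rate for keep_epochs,
    # then ramp linearly down to zero over decay_epochs.
    if epoch < keep_epochs:
        return 1.0
    return max(0.0, 1.0 - (epoch - keep_epochs) / float(decay_epochs))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```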
Data
- `--vvt_dataroot`: path to the FW-GAN VVT dataset
- `--warp_cloth_dir`: path to the pre-warped cloths generated by the WarpModule (default: `warp-cloth`)
- `--batch_size`: number of samples per batch
- `--person_inputs`: type of person representation; generally `agnostic` plus either `cocopose` or `densepose` (see the input sketch after this list)
- `--cloth_inputs`: type of cloth representation (default: `cloth`)
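How the chosen inputs reach the model is easiest to see as tensors. The sketch below assumes each representation is a 3-channel image at 256x192 and that representations are concatenated channel-wise before entering the generator; both are illustrative assumptions, not the repo's exact pipeline:

```python
import torch

batch = 4
# Assumed 3-channel, 256x192 representations (shapes are illustrative).
agnostic = torch.randn(batch, 3, 256, 192)
densepose = torch.randn(batch, 3, 256, 192)
cloth = torch.randn(batch, 3, 256, 192)

# --person_inputs agnostic densepose  and  --cloth_inputs cloth
generator_input = torch.cat([agnostic, densepose, cloth], dim=1)
print(generator_input.shape)  # torch.Size([4, 9, 256, 192])
```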
Checkpointing and logging
- `--display_count`: how often, in steps, to log to TensorBoard
- `--save_count`: how often, in steps, to save a checkpoint
- `--checkpoint`: resume training from this checkpoint (path to a `.ckpt` file)
Choosing Architecture Design
- `--self_attn`: flag to include attention layers in the model architecture
- `--flow_warp`: flag to add optical flow to the model; requires `n_frames_total > 1`
- `--activation`: select the activation function: `relu`, `gelu`, `swish`, or `sine` (see the sketch after this list)
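Because `swish` and `sine` are not built-in PyTorch layer names, a factory along the following lines shows how the `--activation` choices might map to modules. This is a sketch, not the repo's code: `Sine` and `get_activation` are hypothetical helpers, and `swish` is assumed to equal `nn.SiLU`:

```python
import torch
from torch import nn

class Sine(nn.Module):
    """SIREN-style periodic activation (hypothetical helper)."""
    def forward(self, x):
        return torch.sin(x)

def get_activation(name: str) -> nn.Module:
    # swish(x) = x * sigmoid(x), which PyTorch ships as nn.SiLU.
    table = {
        "relu": nn.ReLU(),
        "gelu": nn.GELU(),
        "swish": nn.SiLU(),
        "sine": Sine(),
    }
    return table[name]
```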
... and more! For a complete list of options, run `python train.py --help`.
Instructions: UNet-Mask Baseline
python train.py \
--name train_shineon \
--model unet \
--batch_size 4 \
--person_inputs densepose agnostic \
--cloth_inputs cloth \
--val_check_interval 0.05 \
--self_attn \
--accumulated_batches 16 \
--activation gelu \
--warp_cloth_dir /path/to/output/warp/cloth/directory
Instructions: WarpModule
python train.py \
--name train_warp \
--model warp \
--workers 4 \
--batch_size 4
Instructions: SAMS-GAN
A general train command:
python train.py \
--name "SAMS-GAN_train" \
--model sams \
--ngf_pow_outer 6 \
--ngf_pow_inner 10 \
--n_frames_total 5 \
--n_frames_now 1 \
--batch_size 4 \
--workers 8
The SAMS-GAN generator is an encoder-decoder architecture: the outer layers start at a higher resolution (h × w) with fewer features, while the inner layers have a lower resolution and more features. Unlike the other models, SAMS does NOT use `--ngf` for generator features.
Number of Features
The number of features in the outer layers equals `pow(ngf_power_base, ngf_pow_outer)`; by default, the outer layers have 2^6 = 64 features.
The number of features in the inner layers equals `pow(ngf_power_base, ngf_pow_inner)`; by default, the inner layers have 2^10 = 1024 features.
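If we assume the feature count doubles at each encoder layer between those two bounds (an illustration of the arithmetic, not the exact architecture), the per-layer channel counts at the defaults are:

```python
ngf_power_base = 2
ngf_pow_outer, ngf_pow_inner = 6, 10  # the defaults described above

channels = [ngf_power_base ** p for p in range(ngf_pow_outer, ngf_pow_inner + 1)]
print(channels)  # [64, 128, 256, 512, 1024]  (outermost -> innermost)
```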
Attention Layers
Self-Attentive MultiSPADE (SAMS) layer indices can be chosen with:

- `--attention_middle_indices` for middle layers
- `--attention_decoder_indices` for decoder layers

Negative index selection is supported; e.g., `--attention_decoder_indices -1 -2` puts attention in the last two decoder layers.
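The negative indices resolve exactly like Python list indices. A quick check, with a hypothetical 5-layer decoder:

```python
num_decoder_layers = 5  # hypothetical decoder depth
attention_decoder_indices = [-1, -2]  # as in --attention_decoder_indices -1 -2

# Python-style normalization of negative indices.
resolved = sorted(i % num_decoder_layers for i in attention_decoder_indices)
print(resolved)  # [3, 4] -> SAMS attention in the last two decoder layers
```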
SAMS-GAN has two discriminators: Multiscale, which operates on the current frame at several image resolutions, and Temporal, which operates on the past `--n_frames_now` frames at a single image resolution. Discriminator size is uniformly adjusted with `--ndf` (default: 64).
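A minimal sketch of the multiscale idea, in the spirit of pix2pixHD-style discriminators (the layer sizes and depth here are illustrative, not the repo's exact networks):

```python
import torch
from torch import nn

def make_patch_d(ndf=64):
    # Small PatchGAN-style discriminator; ndf mirrors --ndf.
    return nn.Sequential(
        nn.Conv2d(3, ndf, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
        nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
        nn.Conv2d(ndf * 2, 1, 4, padding=1),
    )

class MultiscaleD(nn.Module):
    def __init__(self, num_scales=3, ndf=64):
        super().__init__()
        self.ds = nn.ModuleList(make_patch_d(ndf) for _ in range(num_scales))
        self.down = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, frame):
        # Score the same frame at progressively halved resolutions.
        outs = []
        for d in self.ds:
            outs.append(d(frame))
            frame = self.down(frame)
        return outs
```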
We use progressive video-frame training to speed up convergence: we start by generating a single image, then manually increase the number of frames to the maximum that fits on the GPU.
- `--n_frames_total`: sets the size of the generation buffer and how many previous frames are fed into the generator as input. Aim for the maximum that fits on the GPU; 5 or more is ideal. Note that this effectively scales up the batch size, so choosing between `--batch_size` and `--n_frames_total` is a trade-off.
- `--n_frames_now`: the number of frames to actually train on right now; the remaining frames are masked with zeros (see the sketch after this list). Progressively increase this value from 1 up to `--n_frames_total`.
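The zero-masking can be sketched in a few lines; the (batch, frame, channel, height, width) layout is an assumption for illustration:

```python
import torch

n_frames_total, n_frames_now = 5, 2  # buffer 5 frames, train on the first 2
batch, channels, h, w = 4, 3, 256, 192

# Assumed layout: (batch, frame, channel, height, width).
frames = torch.randn(batch, n_frames_total, channels, h, w)

# Frames beyond --n_frames_now are masked with zeros at this stage.
frames[:, n_frames_now:] = 0.0
```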