diff --git a/train.py b/train.py index 7813c038..6258e0a1 100644 --- a/train.py +++ b/train.py @@ -131,6 +131,7 @@ def parse_comma_separated_list(s): @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True) # Optional features. @click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--cond-D-nofix', help='For old cond models trained w/o fix in D conditioning', type=bool, default=False, show_default=True) @click.option('--mirror', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--mirror-y', help='Enable dataset y-flips', metavar='BOOL', type=bool, default=False, show_default=True) @click.option('--aug', help='Augmentation mode', type=click.Choice(['noaug', 'ada', 'fixed']), default='ada', show_default=True) @@ -214,6 +215,8 @@ def main(**kwargs): c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth + if not opts.cond_d_nofix: + c.D_kwargs.mapping_kwargs.num_layers = c.G_kwargs.mapping_kwargs.num_layers c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezem c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezee c.D_kwargs.block_kwargs.freeze_layers = opts.freezed