-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathcfg.py
152 lines (146 loc) · 4.23 KB
/
cfg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# -*- coding: utf-8 -*-
# @Date : 2019-07-25
# @Author : Xinyu Gong ([email protected])
# @Link : None
# @Version : 0.0
import argparse
def str2bool(v):
    """Convert a command-line string into a bool for use as an argparse ``type``.

    Accepted spellings (case-insensitive):
      * true:  'yes', 'true', 't', 'y', '1'
      * false: 'no', 'false', 'f', 'n', '0'

    A value that is already a ``bool`` is returned unchanged. This lets the
    helper be called directly on defaults or already-parsed values.

    Args:
        v: The raw value, normally a string from the command line.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a bool and not one of the
            accepted strings. Non-string inputs used to crash with
            AttributeError on ``.lower()``.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')
    # Normalize once instead of calling .lower() in each branch.
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args(args=None):
    """Build the training/evaluation argument parser and parse arguments.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default, and the original behaviour), argparse reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--max_epoch',
        type=int,
        default=200,
        help='number of epochs of training')
    parser.add_argument(
        '--max_iter',
        type=int,
        default=None,
        help='set the max iteration number')
    # Short aliases use a single dash with a multi-character name; argparse
    # accepts these as distinct option strings.
    parser.add_argument(
        '-gen_bs',
        '--gen_batch_size',
        type=int,
        default=64,
        help='size of the batches')
    parser.add_argument(
        '-dis_bs',
        '--dis_batch_size',
        type=int,
        default=64,
        help='size of the batches')
    parser.add_argument(
        '--g_lr',
        type=float,
        default=0.0002,
        help='adam: gen learning rate')
    parser.add_argument(
        '--d_lr',
        type=float,
        default=0.0002,
        help='adam: disc learning rate')
    parser.add_argument(
        '--lr_decay',
        action='store_true',
        help='learning rate decay or not')
    parser.add_argument(
        '--beta1',
        type=float,
        default=0.0,
        help='adam: decay of first order momentum of gradient')
    # beta2 controls the second-moment estimate; the previous help text
    # incorrectly duplicated beta1's "first order" description.
    parser.add_argument(
        '--beta2',
        type=float,
        default=0.9,
        help='adam: decay of second order momentum of gradient')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    parser.add_argument(
        '--latent_dim',
        type=int,
        default=128,
        help='dimensionality of the latent space')
    parser.add_argument(
        '--img_size',
        type=int,
        default=32,
        help='size of each image dimension')
    parser.add_argument(
        '--channels',
        type=int,
        default=3,
        help='number of image channels')
    parser.add_argument(
        '--n_critic',
        type=int,
        default=1,
        help='number of training steps for discriminator per iter')
    parser.add_argument(
        '--val_freq',
        type=int,
        default=20,
        help='interval between each validation')
    parser.add_argument(
        '--print_freq',
        type=int,
        default=50,
        help='interval between each verbose')
    parser.add_argument(
        '--load_path',
        type=str,
        help='The reload model path')
    parser.add_argument(
        '--exp_name',
        type=str,
        help='The name of exp')
    # str2bool (defined in this module) lets these be given as e.g.
    # "--d_spectral_norm true".
    parser.add_argument(
        '--d_spectral_norm',
        type=str2bool,
        default=False,
        help='add spectral_norm on discriminator?')
    parser.add_argument(
        '--g_spectral_norm',
        type=str2bool,
        default=False,
        help='add spectral_norm on generator?')
    parser.add_argument(
        '--dataset',
        type=str,
        default='cifar10',
        help='dataset type')
    parser.add_argument(
        '--data_path',
        type=str,
        default='./data',
        help='The path of data set')
    parser.add_argument('--init_type', type=str, default='normal',
                        choices=['normal', 'orth', 'xavier_uniform', 'false'],
                        help='The init type')
    parser.add_argument('--gf_dim', type=int, default=64,
                        help='The base channel num of gen')
    parser.add_argument('--df_dim', type=int, default=64,
                        help='The base channel num of disc')
    parser.add_argument(
        '--model',
        type=str,
        default='sngan_cifar10',
        help='path of model')
    parser.add_argument('--eval_batch_size', type=int, default=100)
    parser.add_argument('--num_eval_imgs', type=int, default=50000)
    parser.add_argument(
        '--bottom_width',
        type=int,
        default=4,
        help="the base resolution of the GAN")
    parser.add_argument('--random_seed', type=int, default=12345)
    # args=None makes argparse fall back to sys.argv[1:].
    opt = parser.parse_args(args)
    return opt