-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
188 lines (165 loc) · 9.14 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import time
import torch
import configargparse
from networks.dm_nerf import get_embedder, DM_NeRF
def config_parser():
    """Build the configargparse parser holding every training, rendering,
    logging, semantic-instance, and manipulation option for DM-NeRF.

    Returns:
        configargparse.ArgumentParser: the fully configured parser;
        callers invoke ``parse_args()`` on it themselves.
    """

    def _str2bool(value):
        # argparse's type=bool treats ANY non-empty string (including the
        # string "False") as True, so boolean flags could never be disabled
        # from the command line or a config file. Parse common spellings
        # explicitly instead; ValueError is reported by argparse as an
        # invalid argument value.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError('expected a boolean value, got %r' % (value,))

    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, type=str, default='train_configs/replica/office_0.txt',
                        help='train_configs file path')
    parser.add_argument("--expname", type=str, default='office_0',
                        help='experiment name')
    parser.add_argument("--log_time", default=None,
                        help="save as time level")
    parser.add_argument("--basedir", type=str, default='./logs',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/replica/office_0',
                        help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    # 32*32*4
    parser.add_argument("--N_train", type=int, default=4096,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    # 250
    parser.add_argument("--lrate_decay", type=int, default=500,
                        help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--N_test", type=int, default=1024 * 2,
                        help='number of rays processed in parallel, decrease if running out of memory')
    # NOTE: was `type=bool` with a help string copy-pasted from --N_test;
    # both fixed.
    parser.add_argument("--is_train", type=_str2bool, default=True,
                        help='whether the run is in training mode')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

    # semantic instance network options
    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    # 0
    parser.add_argument("--N_importance", type=int, default=128,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--render", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # training options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # datasets options
    parser.add_argument("--testskip", type=int, default=10,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
    parser.add_argument("--resize", action='store_true',
                        help='will resize image and instance map shape of ScanNet dataset')

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--near", type=float,
                        help='set the nearest depth')
    parser.add_argument("--far", type=float,
                        help='set the farthest depth')
    parser.add_argument("--crop_width", type=int,
                        help='set the width of crop')
    parser.add_argument("--crop_height", type=int,
                        help='set the height of crop')

    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_save", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_test", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')

    # semantic instance options
    parser.add_argument("--weakly_mode", type=str, default='weakly_ins',
                        help="select how to weakly instance labels can be set as weakly_ins, weakly_img, weakly_click")
    parser.add_argument("--weakly_value", type=float, default=1.0,
                        help="select how to weakly instance labels, 0-1")
    parser.add_argument("--over_penalize", action='store_true',
                        help="aim to penalize unlabeled rays to air")
    parser.add_argument("--tolerance", type=float, default=None,
                        help="move the center of Gaussian Distribution to the depth - tolerance")
    parser.add_argument("--deta_w", type=float, default=None,
                        help="control the Amplitude of Gaussian Distribution")

    # visualizer hyper-parameter
    parser.add_argument("--target_label", type=int, default=None,
                        help='sign the instance you want to move')
    parser.add_argument("--center_index", type=int, default=None,
                        help='sign the instance center which you want to move')
    parser.add_argument("--ori_pose", type=int, default=None,
                        help='sign the instance center')

    # manipulation parameter
    parser.add_argument("--mani_demo", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--mani_eval", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--mani_mode", type=str, default='rotation',
                        help='select operation mode includes translation, rotation, scale, multi')
    parser.add_argument("--views", type=int, default=720,
                        help="the amount of generated view")
    # these three were also `type=bool` (always-True bug); see _str2bool above
    parser.add_argument("--translation", type=_str2bool, default=False,
                        help='do the shift operation')
    parser.add_argument("--rotation", type=_str2bool, default=False,
                        help='do the rotation operation')
    parser.add_argument("--scale", type=_str2bool, default=False,
                        help='do the scale down or scale up operation')

    # 3D color mesh extraction
    parser.add_argument("--mesh", action='store_true',
                        help='generate a 3d point cloud file')
    return parser
def create_nerf(args):
    """Instantiate the coarse and fine DM-NeRF MLPs plus their embedders.

    Returns a 5-tuple: (position embedder, view-direction embedder,
    coarse model, fine model, args). Both models are moved to
    ``args.device`` and printed for logging.

    NOTE(review): relies on ``args.ins_num`` being set by the caller —
    it is not declared in config_parser; confirm upstream.
    """
    pos_embed, pos_channels = get_embedder(args.multires, args.i_embed)
    dir_embed, dir_channels = get_embedder(args.multires_views, args.i_embed)

    def _build_model():
        # Coarse and fine networks share an identical architecture.
        return DM_NeRF(args.netdepth, args.netwidth, pos_channels,
                       dir_channels, [4], args.ins_num).to(args.device)

    coarse_net = _build_model()
    print(coarse_net)
    fine_net = _build_model()
    print(fine_net)
    return pos_embed, dir_embed, coarse_net, fine_net, args
def initial():
    """Parse CLI/config arguments, pick the torch device, create the log
    directory, and snapshot the resolved arguments and config file there.

    Returns:
        argparse.Namespace: parsed arguments, with ``log_time`` and
        ``device`` attributes filled in.
    """
    parser = config_parser()
    args = parser.parse_args()
    # get log time: default to wall-clock timestamp when not given explicitly
    if args.log_time is None:
        args.log_time = time.strftime("%Y%m%d%H%M", time.localtime())
    if torch.cuda.is_available():
        # make newly created tensors default to CUDA floats
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        args.device = torch.device("cuda:0")
    else:
        args.device = torch.device("cpu")
    log_dir = os.path.join(args.basedir, args.expname, args.log_time)
    print(log_dir)
    os.makedirs(log_dir, exist_ok=True)
    # dump the fully-resolved arguments for reproducibility
    args_path = os.path.join(log_dir, 'args.txt')
    with open(args_path, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        # copy the config file into the log dir; the original code opened
        # args.config without ever closing the handle (resource leak)
        config_copy = os.path.join(log_dir, 'configs.txt')
        with open(args.config, 'r') as src, open(config_copy, 'w') as dst:
            dst.write(src.read())
    return args