-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainer.py
51 lines (44 loc) · 2.01 KB
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import logging
import argparse
import pyrallis
import os
from GeoNeusTrainer import GeoNeusTrainer
from LatentPaintTrainer import LatentPaintTrainer
from confs.train_config import TrainConfig
# NOTE(review): setting CURL_CA_BUNDLE to an empty string at import time looks like
# a workaround to skip TLS certificate verification for HTTP downloads (older
# `requests` releases treat an empty value as "do not verify"; newer ones ignore it).
# Confirm this is still needed — it weakens transport security for the whole process.
os.environ['CURL_CA_BUNDLE'] = ''
def _parse_interpolate_indices(mode):
    """Return the two image indices encoded in an ``interpolate_<i>_<j>`` mode.

    Args:
        mode: Mode string such as ``"interpolate_0_12"``.

    Returns:
        Tuple ``(img_idx_0, img_idx_1)`` of ints.

    Raises:
        ValueError: if ``mode`` does not split on ``'_'`` into exactly three
            parts, or if the last two parts are not integers.
    """
    parts = mode.split('_')
    if len(parts) != 3:
        raise ValueError(
            f"interpolate mode must look like 'interpolate_<i>_<j>', got {mode!r}"
        )
    try:
        return int(parts[1]), int(parts[2])
    except ValueError as err:
        raise ValueError(
            f"interpolate mode indices must be integers, got {mode!r}"
        ) from err


@pyrallis.wrap()
def main(cfg: TrainConfig):
    """Select the CUDA device and default tensor type, then run the chosen trainer.

    ``pyrallis.wrap()`` builds ``cfg`` from the command line / config file, so
    the script calls ``main()`` with no arguments. ``cfg.global_setting.mode``
    selects the action:

    * ``"latent_paint"``: ``LatentPaintTrainer(cfg)``; runs ``full_eval()``
      when ``cfg.log.eval_only`` is truthy, otherwise ``train()``.
    * ``"train"``, ``"validate_mesh"``, ``"validate_image"``, ``"eval_image"``:
      the matching ``GeoNeusTrainer`` method.
    * ``"interpolate_<i>_<j>"``: ``GeoNeusTrainer.interpolate_view(i, j)``.

    Raises:
        ValueError: if the mode is not one of the above, or an ``interpolate``
            mode does not carry two integer image indices. Both are checked
            before the (expensive) ``GeoNeusTrainer`` is constructed.
    """
    print('Hello Wooden')
    torch.cuda.set_device(cfg.global_setting.gpu)
    # Either branch makes newly created tensors default to CUDA.
    # NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1;
    # torch.set_default_dtype + torch.set_default_device is the documented replacement.
    if cfg.global_setting.half:
        torch.set_default_tensor_type('torch.cuda.HalfTensor')
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    mode = cfg.global_setting.mode

    if mode == "latent_paint":
        trainer = LatentPaintTrainer(cfg)
        if cfg.log.eval_only:
            trainer.full_eval()
        else:
            trainer.train()
        return

    # Validate the mode before building the trainer: an unknown mode used to
    # construct GeoNeusTrainer and then silently do nothing.
    interpolate_indices = None
    if mode.startswith('interpolate'):  # Interpolate views given two image indices
        interpolate_indices = _parse_interpolate_indices(mode)
    elif mode not in ('train', 'validate_mesh', 'validate_image', 'eval_image'):
        raise ValueError(f"Unknown global_setting.mode: {mode!r}")

    trainer = GeoNeusTrainer(
        cfg.neus.neus_cfg_path,
        mode,
        cfg.neus.case,
        cfg.neus.is_continue,
        cfg.neus.checkpoint,
        cfg.neus.suffix,
    )
    if interpolate_indices is not None:
        trainer.interpolate_view(*interpolate_indices)
    elif mode == 'train':
        trainer.train()
    elif mode == 'validate_mesh':
        # NOTE: 'validate_mesh_womask' / 'validate_mesh_ori' variants are not wired
        # up here; the old commented-out calls referenced a removed argparse
        # `args` / `runner` and were dropped.
        trainer.validate_mesh(
            world_space=True,
            resolution=512,
            threshold=cfg.neus.mcube_threshold,
            dilation=cfg.neus.dilation,
        )
    elif mode == 'validate_image':
        trainer.validate_image()
    else:  # 'eval_image' (the only remaining mode accepted above)
        trainer.eval_image()


if __name__ == "__main__":
    main()