forked from torch-points3d/torch-points3d
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
148 lines (118 loc) · 4.12 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import hydra
import logging
from omegaconf import OmegaConf
# Import building function for model and dataset
from torch_points3d.datasets.dataset_factory import instantiate_dataset
from torch_points3d.models.model_factory import instantiate_model
# Import BaseModel / BaseDataset for type checking
from torch_points3d.models.base_model import BaseModel
from torch_points3d.datasets.base_dataset import BaseDataset
# Import from metrics
from torch_points3d.metrics.base_tracker import BaseTracker
from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq
from torch_points3d.metrics.model_checkpoint import ModelCheckpoint
# Utils import
from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve_model
from torch_points3d.utils.colors import COLORS
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def eval_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Evaluate the model on the validation loader and print the tracked metrics.

    The tracker is reset for the ``"val"`` stage, then every batch of
    ``dataset.val_dataloader`` is forwarded through the model ``voting_runs``
    times with gradient tracking disabled. The tracker is updated after each
    batch and finalised / summarised once, after all runs are done.

    Args:
        model: Model to evaluate; ``set_input`` and ``forward`` are called on it.
        dataset: Object exposing ``val_dataloader``.
        device: Device passed to ``model.set_input``.
        tracker: Metric tracker fed after every batch.
        checkpoint: Not used here; kept for interface compatibility with callers.
        voting_runs: Number of complete passes over the validation loader.
        tracker_options: Keyword arguments forwarded to ``tracker.track`` and
            ``tracker.finalise``. ``None`` (the default) means no extra options.
    """
    # A None default avoids a shared mutable default dict and also tolerates
    # callers that explicitly hand over None (``**None`` would raise TypeError).
    options = {} if tracker_options is None else tracker_options

    tracker.reset("val")
    loader = dataset.val_dataloader
    for _ in range(voting_runs):
        with Ctq(loader) as tq_val_loader:
            for data in tq_val_loader:
                # Inference only: no autograd graph is needed for the forward pass.
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                tracker.track(model, **options)
                tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

    tracker.finalise(**options)
    tracker.print_summary()
def test_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Evaluate the model on every test loader and print metrics per loader.

    For each loader in ``dataset.test_dataloaders`` the tracker is reset with
    the loader's dataset name as stage, all batches are forwarded through the
    model ``voting_runs`` times with gradient tracking disabled, and the
    tracker is finalised and summarised before moving to the next loader.

    Args:
        model: Model to evaluate; ``set_input`` and ``forward`` are called on it.
        dataset: Object exposing ``test_dataloaders`` (an iterable of loaders
            whose ``dataset`` attribute has a ``name``).
        device: Device passed to ``model.set_input``.
        tracker: Metric tracker fed after every batch.
        checkpoint: Not used here; kept for interface compatibility with callers.
        voting_runs: Number of complete passes over each test loader.
        tracker_options: Keyword arguments forwarded to ``tracker.track`` and
            ``tracker.finalise``. ``None`` (the default) means no extra options.
    """
    # A None default avoids a shared mutable default dict and also tolerates
    # callers that explicitly hand over None (``**None`` would raise TypeError).
    options = {} if tracker_options is None else tracker_options

    for loader in dataset.test_dataloaders:
        # Each test loader is tracked as its own stage, keyed by dataset name.
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        for _ in range(voting_runs):
            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    # Inference only: no autograd graph is needed for the forward pass.
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                    tracker.track(model, **options)
                    tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

        tracker.finalise(**options)
        tracker.print_summary()
def run(
    cfg,
    model,
    dataset: BaseDataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run validation and/or test evaluation, depending on what the dataset offers.

    ``eval_epoch`` is called when ``dataset.has_val_loader`` is truthy and
    ``test_epoch`` when ``dataset.has_test_loaders`` is truthy; both may run.

    Args:
        cfg: Run configuration. Not read by this function; kept for interface
            compatibility with callers.
        model: Model to evaluate.
        dataset: Dataset exposing ``has_val_loader`` / ``has_test_loaders``.
        device: Device the batches are sent to.
        tracker: Metric tracker shared by both evaluation stages.
        checkpoint: Checkpoint object, passed through to the epoch functions.
        voting_runs: Number of passes over each loader.
        tracker_options: Keyword arguments for the tracker. ``None`` (the
            default) means no extra options.
    """
    # Normalise once so the epoch functions always receive a real mapping,
    # and so no mutable default dict is shared between calls.
    options = {} if tracker_options is None else tracker_options

    if dataset.has_val_loader:
        eval_epoch(
            model, dataset, device, tracker, checkpoint, voting_runs=voting_runs, tracker_options=options
        )

    if dataset.has_test_loaders:
        test_epoch(
            model, dataset, device, tracker, checkpoint, voting_runs=voting_runs, tracker_options=options
        )
@hydra.main(config_path="conf/eval.yaml")
def main(cfg):
    """Restore a model from a checkpoint and evaluate it on its dataset."""
    OmegaConf.set_struct(cfg, False)

    # Use the GPU only when one is available and the config asks for it.
    use_cuda = torch.cuda.is_available() and cfg.cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    log.info("DEVICE : {}".format(device))

    # The cuDNN backend switch is taken straight from the config.
    torch.backends.cudnn.enabled = cfg.enable_cudnn

    # The checkpoint supplies the dataset configuration and builds the model.
    checkpoint = ModelCheckpoint(cfg.checkpoint_dir, cfg.model_name, cfg.weight_name, strict=True)
    dataset = instantiate_dataset(checkpoint.data_config)
    model = checkpoint.create_model(dataset, weight_name=cfg.weight_name)
    log.info(model)

    # Only parameters that require gradients are counted.
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info("Model size = %i", trainable_count)

    # Dataloaders are created after the model because the model is passed in.
    dataset.create_dataloaders(
        model,
        cfg.batch_size,
        cfg.shuffle,
        cfg.num_workers,
        cfg.precompute_multi_scale,
    )
    log.info(dataset)

    # Switch to eval mode, optionally re-enable dropout, then move to the device.
    model.eval()
    if cfg.enable_dropout:
        model.enable_dropout_in_eval()
    model = model.to(device)

    # NOTE(review): the meaning of the two False flags is not visible here;
    # confirm against BaseDataset.get_tracker.
    tracker: BaseTracker = dataset.get_tracker(False, False)

    # Run the evaluation on whichever loaders the dataset provides.
    run(
        cfg,
        model,
        dataset,
        device,
        tracker,
        checkpoint,
        voting_runs=cfg.voting_runs,
        tracker_options=cfg.tracker_options,
    )
# Script entry point: only run the evaluation when executed directly.
if __name__ == "__main__":
    main()