-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
95 lines (69 loc) · 2.77 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from pathlib import Path
import os
import fire
import unisal
def train(eval_sources=('SALICON',),
          **kwargs):
    """Run training, then score the trained model on each evaluation source.

    Args:
        eval_sources: Iterable of dataset names. After ``trainer.fit()``
            finishes, each name is passed as ``source=`` to
            ``trainer.score_model``.
        **kwargs: Forwarded unchanged to ``unisal.train.Trainer``.
    """
    trainer = unisal.train.Trainer(**kwargs)
    try:
        trainer.fit()
        for source in eval_sources:
            trainer.score_model(source=source)
        trainer.export_scalars()
    finally:
        # Close the writer even when fitting, scoring or exporting raises.
        # Otherwise it would be left open on the error path.
        trainer.writer.close()
def load_trainer(train_id=None):
    """Instantiate a Trainer from the config saved in a run directory.

    The run directory is ``$TRAIN_DIR/<train_id>``. The chosen run name is
    printed before the environment variable is read.

    Args:
        train_id: Name of the run folder under ``$TRAIN_DIR``. When ``None``,
            ``'pretrained_unisal'`` is used.

    Returns:
        Whatever ``unisal.train.Trainer.init_from_cfg_dir`` returns for the
        resolved run directory.

    Raises:
        KeyError: If the ``TRAIN_DIR`` environment variable is not set.
    """
    run_name = 'pretrained_unisal' if train_id is None else train_id
    print(f"Train ID: {run_name}")
    run_dir = Path(os.environ["TRAIN_DIR"]) / run_name
    return unisal.train.Trainer.init_from_cfg_dir(run_dir)
def score_model(
        train_id=None,
        sources=('SALICON',),
        **kwargs):
    """Evaluate a saved model on one or more datasets.

    Args:
        train_id: Run folder name under ``$TRAIN_DIR``. ``None`` lets
            ``load_trainer`` pick its default run.
        sources: Dataset names. Each one is passed as ``source=`` to
            ``trainer.score_model``.
        **kwargs: Forwarded to every ``trainer.score_model`` call.
    """
    scorer = load_trainer(train_id)
    for dataset_name in sources:
        scorer.score_model(source=dataset_name, **kwargs)
def generate_predictions(
        train_id='2020-09-23_09:50:34_unisal',
        sources=('SALICON',),
        **kwargs):
    """Generate predictions with a trained model for each dataset in sources.

    For 'MIT1003' and 'MIT300', the fine-tuned "ft_mit1003" weights are
    loaded first, and that source is run with ``model_domain='SALICON'`` and
    ``load_weights=False``. These kwarg overrides apply only to the MIT
    source itself.

    Args:
        train_id: Run folder name under ``$TRAIN_DIR`` to load the trainer from.
        sources: Dataset names. Each one is passed as ``source=`` to
            ``trainer.generate_predictions``.
        **kwargs: Extra keyword arguments forwarded to every
            ``trainer.generate_predictions`` call.
    """
    trainer = load_trainer(train_id)
    for source in sources:
        # Work on a per-source copy. Updating ``kwargs`` in place would leak
        # the MIT overrides into every source processed afterwards.
        source_kwargs = dict(kwargs)

        # Load fine-tuned weights for MIT datasets
        if source in ('MIT1003', 'MIT300'):
            # NOTE(review): the loaded weights and the x_val_step change stay
            # on ``trainer`` for any later sources. Presumably later non-MIT
            # calls reload weights themselves since they no longer receive
            # load_weights=False; confirm against Trainer.generate_predictions.
            trainer.model.load_weights(trainer.train_dir, "ft_mit1003")
            trainer.salicon_cfg['x_val_step'] = 0
            source_kwargs.update(
                {'model_domain': 'SALICON', 'load_weights': False})

        trainer.generate_predictions(source=source, **source_kwargs)
def predictions_from_folder(
        folder_path, is_video, source=None, train_id=None, model_domain=None):
    """Run a trained model on the contents of a single folder.

    Args:
        folder_path: Path handed to ``trainer.generate_predictions_from_path``.
        is_video: Forwarded unchanged to the trainer.
        source: Dataset name forwarded as ``source=``.
        train_id: Run folder name under ``$TRAIN_DIR``. ``None`` lets
            ``load_trainer`` pick its default run.
        model_domain: Forwarded as ``model_domain=``.
    """
    loaded_trainer = load_trainer(train_id)
    loaded_trainer.generate_predictions_from_path(
        folder_path,
        is_video,
        source=source,
        model_domain=model_domain,
    )
def predict_examples(train_id='2020-09-21_10:36:24_unisal'):
    """Generate predictions for each example folder in ``data/`` next to this file.

    Each non-hidden subdirectory of ``data`` is one source, named after the
    folder:

    - 'SALICON' and 'MIT1003' folders are processed as image folders.
    - Any other folder is treated as a video source. Each of its non-hidden
      entries is processed as one video.

    Entries are visited in sorted order so runs are reproducible.

    Args:
        train_id: Run folder name under ``$TRAIN_DIR`` used for every
            prediction call.
    """
    data_dir = Path(__file__).resolve().parent / "data"
    for example_folder in sorted(data_dir.glob("*")):
        # Skip plain files and hidden folders (e.g. .ipynb_checkpoints).
        if not example_folder.is_dir() or example_folder.name.startswith('.'):
            continue

        source = example_folder.name
        is_video = source not in ('SALICON', 'MIT1003')

        print(f"\nGenerating predictions for {'video' if is_video else 'image'} "
              f"folder\n{source}")

        if not is_video:
            predictions_from_folder(
                example_folder, is_video, train_id=train_id, source=source)
            continue

        for video_folder in sorted(example_folder.glob('*')):
            # Ignore hidden entries such as .DS_Store.
            if video_folder.name.startswith('.'):
                continue
            predictions_from_folder(
                video_folder, is_video, train_id=train_id, source=source)
if __name__ == "__main__":
    # Expose train() as a command-line interface through Python Fire (already
    # imported above).
    #  - With no arguments this still calls train() with its defaults.
    #  - Extra flags are forwarded to train() and, through **kwargs, to
    #    unisal.train.Trainer.
    #  - Pass sequences as a list literal, e.g. --eval_sources='[SALICON]'.
    #    A bare string would be iterated character by character.
    fire.Fire(train)