-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·43 lines (37 loc) · 1.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Main Entry Point for Package.
Unsupervised Deep-Learning Seminar
LMU Munich
Philipp Koch, 2023
MIT-License
"""
import hydra
from neg_udl.Experiment import Experiment
from omegaconf import DictConfig
import os
import torch
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment at import
# time (python-dotenv), so they are visible before run_experiment executes.
load_dotenv()
@hydra.main(version_base=None, config_path="neg_udl/config")
def run_experiment(cfg: DictConfig) -> None:
    """Instantiate the configured experiment, prepare its data, and run it.

    The ``cfg.experiment`` node is built with ``hydra.utils.instantiate``.
    Run-level settings taken from the rest of ``cfg`` are passed as keyword
    arguments. The resulting object then has ``prepare_dataset()`` and
    ``run()`` called on it, in that order.

    NOTE(review): no ``config_name`` is given to ``hydra.main``, so the
    primary config is presumably chosen on the command line
    (e.g. ``--config-name``) — confirm.

    Args:
        cfg: Composed Hydra configuration. This function reads ``experiment``,
            ``name``, ``data``, ``data_collator``, ``seed``, and the ``model``
            (``name``, ``target_path``, ``tmp_path``, ``freeze_lower``,
            ``freeze_upper``) and ``training`` (``epochs``, ``batch_size``,
            ``lr``, ``eval_steps_n``, ``eval_steps``) sub-trees.
    """
    exp: Experiment = hydra.utils.instantiate(
        cfg.experiment,
        name=cfg.name,
        model_checkpoint=cfg.model.name,
        # Top-level keys of the data node converted to a plain dict;
        # nested nodes are left as OmegaConf containers.
        dataset_config=dict(cfg.data),
        data_collator=cfg.data_collator,
        seed=cfg.seed,
        num_epochs=cfg.training.epochs,
        batch_size=cfg.training.batch_size,
        lr=cfg.training.lr,
        # NOTE(review): `steps` is fed from `training.eval_steps_n` while
        # `eval_steps` comes from `training.eval_steps` — confirm this mapping
        # against Experiment's signature.
        steps=cfg.training.eval_steps_n,
        eval_steps=cfg.training.eval_steps,
        model_target_path=cfg.model.target_path,
        # (lower, upper) pair; presumably the range of layers to freeze —
        # confirm in Experiment.
        freeze_layers=(cfg.model.freeze_lower, cfg.model.freeze_upper),
        model_tmp_path=cfg.model.tmp_path,
    )
    exp.prepare_dataset()
    exp.run()
def configure_cublas_workspace(value: str = ":4096:8") -> str:
    """Ensure ``CUBLAS_WORKSPACE_CONFIG`` is set in the process environment.

    A value already present in the environment is kept. That value may come
    from the shell or from the ``.env`` file read by ``load_dotenv()`` at
    import. ``value`` is only used as a fallback when the variable is unset.
    Presumably this is set for deterministic cuBLAS behaviour under PyTorch —
    confirm.

    Args:
        value: Workspace configuration to apply when the variable is unset.

    Returns:
        The value of ``CUBLAS_WORKSPACE_CONFIG`` in effect after the call.
    """
    return os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", value)


if __name__ == "__main__":
    # Set before run_experiment so it is in place ahead of any CUDA work.
    configure_cublas_workspace()
    # Disabled in the original; enabling it would switch torch multiprocessing
    # to the 'spawn' start method.
    # torch.multiprocessing.set_start_method('spawn')
    run_experiment()