-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
66 lines (48 loc) · 1.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
from runner.base import BaseExperiment
from runner.text2img import Text2ImgExperiment
from runner.unconditional import UnconditionalExperiment
from tools.config_utils import init_experiment_config, override_phase_config, parser_override_args
# Registry of runnable experiments: main() looks up `config.experiment_name`
# here to pick the class that is instantiated with the loaded config.
experiments = dict(
    base=BaseExperiment,
    unconditional=UnconditionalExperiment,
    text2img=Text2ImgExperiment,
)
def parser_args(argv=None):
    """Parse the command line for the diffusion experiment runner.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        A ``(args, kwargs)`` tuple. ``args`` is the namespace holding
        ``config`` and ``phase``. ``kwargs`` is the list of arguments argparse
        did not recognise (e.g. ``a.b=True`` style overrides), returned rather
        than rejected so the caller can apply them to the config.
    """
    # The first positional parameter of ArgumentParser is `prog`, not the
    # description; passing the text as `description=` keeps the real program
    # name in the usage line and shows the text in --help.
    parser = argparse.ArgumentParser(description="Diffusion config")
    parser.add_argument("--config", "-c", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--phase",
        "-p",
        type=str,
        required=True,
        choices=["train", "inference", "sample"],
        help="experiment phase to run",
    )
    # parse_known_args: unknown tokens are handed back instead of causing an error.
    args, kwargs = parser.parse_known_args(argv)
    return args, kwargs
def main():
    """Build the experiment config from the command line and run the requested phase.

    Parses ``--config``/``--phase`` plus extra override arguments via
    ``parser_args``. It loads the config with ``init_experiment_config`` and
    applies the phase and overrides. It then instantiates the class registered
    under ``config.experiment_name`` in ``experiments`` and calls the method
    matching ``config.phase``.

    Raises:
        KeyError: if ``config.experiment_name`` is not registered in ``experiments``.
        ValueError: if ``config.phase`` (which the overrides may change) is not
            one of ``train``, ``inference``, ``sample`` or ``test``.
    """
    args, kwargs = parser_args()
    config = init_experiment_config(args.config)

    # Record the CLI phase in the config, then let override_phase_config
    # adjust it (presumably applies phase-specific settings -- confirm in
    # tools.config_utils).
    config.update({"phase": args.phase})
    config = override_phase_config(config)

    # Apply the extra CLI overrides: 'a.b=True' becomes {'a': {'b': True}}.
    config = parser_override_args(config, kwargs)

    experiment_name = config.experiment_name
    experiment_cls = experiments.get(experiment_name)
    if experiment_cls is None:
        # Explicit check rather than `assert`: asserts vanish under `python -O`,
        # and indexing the dict directly would fail without this message.
        raise KeyError(
            f"Experiment {experiment_name} not found; available: {', '.join(experiments)}"
        )

    # create experiment instance
    experiment_instance = experiment_cls(config)

    phase = config.phase
    if phase == "train":
        experiment_instance.train()
    elif phase == "inference":
        experiment_instance.inference()
    elif phase == "sample":
        experiment_instance.sample()
    elif phase == "test":
        print("The test phase is not implemented in the current code. Please use an external script to test the model.")
        print("the image test script is in the evaluations folder")
    else:
        # argparse restricts --phase, but the overrides above can still set an
        # arbitrary value; fail loudly instead of silently doing nothing.
        raise ValueError(
            f"Unknown phase {phase!r}; expected one of: train, inference, sample, test"
        )
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()