-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample_medigan.py
81 lines (71 loc) · 2.59 KB
/
sample_medigan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# import medigan and initialize Generators
import os
from pathlib import Path
import torch
from medigan import Generators # pip install medigan
# Single module-level medigan entry point; model weights are fetched on demand.
generators = Generators()
# Short model name -> medigan model-zoo id.
# NOTE(review): presumably BCDR breast-mass GANs (DCGAN, WGAN-GP, conditional
# DCGAN) from the medigan model zoo — confirm ids against the medigan docs.
MODEL_IDS = {
"dcgan_bcdr": 5,
"wgangp_bcdr": 6,
"cdgan_bcdr": 12,
}
if __name__ == "__main__":
    # Number of synthetic samples to draw per dataset split.
    samples_sizes = {"train": 10000, "val": 1000, "test": 1000}
    num_repeat = 5  # number of independently sampled dataset copies

    # Sample num_repeat independent synthetic datasets, laid out as
    # data/breastmass_<run>/<model>/<split>/.
    for num in range(num_repeat):
        data_path = Path("data") / f"breastmass_{num}"
        for setting, n in samples_sizes.items():
            # Create every output directory up front so generate() can write.
            for model in MODEL_IDS:
                (data_path / model / setting).mkdir(exist_ok=True, parents=True)
            for model, model_id in MODEL_IDS.items():
                print(f"Generating {n} samples for {model}...")
                generators.generate(
                    model_id=model_id,
                    num_samples=n,
                    install_dependencies=True,
                    output_path=data_path / model / setting,
                )

    # Save each generator's state dict once, after all sampling is done.
    # NOTE(review): the scraped original lost its indentation; this loop is
    # independent of num/setting, so it is placed at the top level of main.
    for model, model_id in MODEL_IDS.items():
        me = generators.get_model_executor(model_id=model_id, install_dependencies=True)
        # Model 12 (cdgan_bcdr) is the conditional generator and emits two
        # channels; the unconditional models emit a single channel.
        is_conditional = model_id == 12
        model_instance = (
            me.deserialized_model_as_lib.Generator(
                nz=100,
                ngf=64,
                nc=2 if is_conditional else 1,
                ngpu=1,
                image_size=128,
                leakiness=0.1,
                conditional=is_conditional,
            )
            .cuda()  # requires a CUDA device; matches map_location below
            .eval()
        )
        # medigan checkpoints store the generator weights under "generator".
        model_instance.load_state_dict(
            state_dict=torch.load(me.package_path, map_location="cuda")["generator"]
        )
        out_dir = Path("trained_models") / model
        out_dir.mkdir(exist_ok=True, parents=True)
        # save model (for testing whether wrapper is working correctly)
        # torch.save(model_instance, f"trained_models/{model}/model.pt")
        # save model state dict
        torch.save(model_instance.state_dict(), out_dir / "model_state_dict.pt")