"""
References
----------
[1] Moskalenko, Viktor, Nikolai Zolotykh, and Grigory Osipov. "Deep Learning for ECG Segmentation." International Conference on Neuroinformatics. Springer, Cham, 2019.
"""
from copy import deepcopy
from pathlib import Path

try:
    import torch_ecg  # noqa: F401
except ModuleNotFoundError:
    import sys

    sys.path.insert(0, str(Path(__file__).absolute().parents[2]))

from torch_ecg.cfg import CFG, DEFAULTS
from torch_ecg.model_configs import (  # noqa: F401
    ECG_SUBTRACT_UNET_CONFIG,
    ECG_UNET_VANILLA_CONFIG,
)
from torch_ecg.utils import ecg_arrhythmia_knowledge as EAK

__all__ = [
    "TrainCfg",
]

_BASE_DIR = Path(__file__).absolute().parent
BaseCfg = CFG()
BaseCfg.fs = 500 # Hz, LUDB data fs
BaseCfg.classes = [
    "p",  # P wave
    "N",  # QRS complex
    "t",  # T wave
    "i",  # isoelectric line
]
# BaseCfg.mask_classes = [
#     "p",  # P wave
#     "N",  # QRS complex
#     "t",  # T wave
# ]
BaseCfg.mask_classes = deepcopy(BaseCfg.classes)
BaseCfg.class_map = CFG(p=1, N=2, t=3, i=0)
# BaseCfg.mask_class_map = CFG({k:v-1 for k,v in BaseCfg.class_map.items() if k!="i"})
BaseCfg.mask_class_map = deepcopy(BaseCfg.class_map)
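# For illustration (a sketch, not from the original file): with `class_map`
# above, a per-sample label sequence ["i", "p", "N", "t"] is encoded as the
# integer mask [0, 1, 2, 3] that serves as the segmentation target.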
BaseCfg.db_dir = None
BaseCfg.bias_thr = int(0.075 * BaseCfg.fs) # TODO: renew this const
# detected waves within `skip_dist` of either end of the signal will be ignored
BaseCfg.skip_dist = int(0.5 * BaseCfg.fs)
BaseCfg.torch_dtype = DEFAULTS.torch_dtype
TrainCfg = CFG()
# configs of files
TrainCfg.db_dir = BaseCfg.db_dir
TrainCfg.log_dir = _BASE_DIR / "log"
TrainCfg.checkpoints = _BASE_DIR / "checkpoints"
TrainCfg.log_dir.mkdir(parents=True, exist_ok=True)
TrainCfg.checkpoints.mkdir(parents=True, exist_ok=True)
TrainCfg.keep_checkpoint_max = 20
TrainCfg.torch_dtype = BaseCfg.torch_dtype
TrainCfg.fs = 500
TrainCfg.train_ratio = 0.8
TrainCfg.classes = deepcopy(BaseCfg.classes)
TrainCfg.class_map = deepcopy(BaseCfg.class_map)
TrainCfg.mask_classes = deepcopy(BaseCfg.mask_classes)
TrainCfg.mask_class_map = deepcopy(BaseCfg.mask_class_map)
TrainCfg.skip_dist = BaseCfg.skip_dist
TrainCfg.leads = EAK.Standard12Leads  # or e.g. ["II"]; the leads used to train the model; None --> all leads
TrainCfg.use_single_lead = False  # use a single lead as input, or all leads in `TrainCfg.leads`
if TrainCfg.use_single_lead:
    TrainCfg.n_leads = 1
else:
    TrainCfg.n_leads = len(TrainCfg.leads)
# as for `start_from` and `end_at`, see ref. [1] section 3.1
TrainCfg.start_from = int(2 * TrainCfg.fs)
TrainCfg.end_at = int(2 * TrainCfg.fs)
TrainCfg.input_len = int(4 * TrainCfg.fs)
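# Concretely, at fs = 500 Hz this skips 2 s (1000 samples) from each end of a
# record and trains on 4 s windows (2000 samples), per ref. [1] section 3.1.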
TrainCfg.over_sampling = 1
# configs of training epochs, batch, etc.
TrainCfg.n_epochs = 300
TrainCfg.batch_size = 32
# TrainCfg.max_batches = 500500
# configs of optimizers and lr_schedulers
TrainCfg.optimizer = "adamw_amsgrad" # "sgd", "adam", "adamw"
TrainCfg.momentum = 0.949  # momentum for the "sgd" optimizer (PyTorch's own default is 0)
TrainCfg.betas = (0.9, 0.999) # default values for corresponding PyTorch optimizers
TrainCfg.decay = 1e-2 # default values for corresponding PyTorch optimizers
TrainCfg.learning_rate = 1e-3 # 1e-4
TrainCfg.lr = TrainCfg.learning_rate
TrainCfg.lr_scheduler = "one_cycle" # "one_cycle", "plateau", "burn_in", "step", None
TrainCfg.lr_step_size = 50
TrainCfg.lr_gamma = 0.1
TrainCfg.max_lr = 2e-3  # for the "one_cycle" scheduler; to be adjusted via experiments
TrainCfg.burn_in = 400
TrainCfg.steps = [5000, 10000]
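# Scheduler semantics as commonly implemented (an assumption about the trainer,
# not stated in this file): "step" multiplies the lr by `lr_gamma` every
# `lr_step_size` epochs (cf. torch.optim.lr_scheduler.StepLR); "one_cycle"
# ramps the lr up to `max_lr` and anneals it back down over the run
# (cf. torch.optim.lr_scheduler.OneCycleLR).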
TrainCfg.early_stopping = CFG() # early stopping according to challenge metric
TrainCfg.early_stopping.min_delta = 0.001 # should be non-negative
TrainCfg.early_stopping.patience = 30
# configs of loss function
TrainCfg.loss = "FocalLoss" # "BCEWithLogitsLoss", "AsymmetricLoss", "CrossEntropyLoss"
TrainCfg.loss_kw = CFG()  # keyword arguments passed to the loss function constructor
TrainCfg.flooding_level = 0.0 # flooding performed if positive
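# Flooding (Ishida et al., 2020) replaces the loss L with |L - b| + b for a
# flooding level b > 0, keeping the training loss hovering around b instead of
# collapsing to zero; the zero value here leaves it disabled.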
TrainCfg.log_every = 1
TrainCfg.monitor = "f1_score"
TrainCfg.model_name = "unet"
ModelCfg = CFG()
ModelCfg.torch_dtype = BaseCfg.torch_dtype
ModelCfg.fs = BaseCfg.fs
ModelCfg.spacing = 1000 / ModelCfg.fs  # sample spacing in ms
ModelCfg.classes = deepcopy(BaseCfg.classes)
ModelCfg.class_map = deepcopy(BaseCfg.class_map)
ModelCfg.mask_classes = deepcopy(BaseCfg.mask_classes)
ModelCfg.mask_class_map = deepcopy(BaseCfg.mask_class_map)
ModelCfg.n_leads = TrainCfg.n_leads
ModelCfg.skip_dist = BaseCfg.skip_dist
ModelCfg.model_name = TrainCfg.model_name
ModelCfg.unet = deepcopy(ECG_UNET_VANILLA_CONFIG)
# TODO: add detailed ModelCfg
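# A minimal usage sketch (hypothetical, not part of the original file),
# assuming the `ECG_UNET` model class exported by `torch_ecg.models` accepts
# the classes / leads / config triple:
#
#     from torch_ecg.models import ECG_UNET
#
#     model = ECG_UNET(
#         classes=ModelCfg.classes,
#         n_leads=ModelCfg.n_leads,
#         config=ModelCfg.unet,
#     )
#
# `model` would then map a (batch, n_leads, input_len) signal tensor to
# per-sample scores over the 4 wave classes.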