-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
36 lines (34 loc) · 1.76 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import os
import time
from experiments import BCI2aExperiment, PhysionetExperiment
from utils import read_yaml, get_logger
def build_parser():
    """Build the command-line parser for the experiment entry point.

    Returns:
        argparse.ArgumentParser: parser accepting ``--dataset``, ``--model``,
        ``--config``, ``--strategy`` and ``--save``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='bci2a', choices=['bci2a', 'physionet'],
                        help='data set used of the experiments')
    parser.add_argument('--model', type=str, default='EEGNet',
                        choices=['EEGNet', 'EEGConformer', 'ATCNet', 'EEGInception', 'EEGITNet'],
                        help='model used of the experiments')
    parser.add_argument('--config', type=str, default='default', help='config file name(.yaml format)')
    parser.add_argument('--strategy', type=str, default='within-subject', choices=['cross-subject', 'within-subject'],
                        help='experiments strategy on subjects')
    parser.add_argument('--save', action='store_true', help='save the pytorch model and history(follow skorch)')
    return parser


def resolve_config_name(dataset, model, config):
    """Return the YAML config file name to load for this run.

    Args:
        dataset: dataset name selected on the command line.
        model: model name selected on the command line.
        config: value of ``--config``. The sentinel ``'default'`` is expanded
            to ``'{dataset}_{model}_default.yaml'``, so each dataset/model pair
            gets its own default file.

    Returns:
        str: the config file name. Any value other than ``'default'`` is
        returned unchanged.
    """
    if config == 'default':
        return f'{dataset}_{model}_{config}.yaml'
    return config


def main(argv=None):
    """Parse arguments, load the config, set up logging and run one experiment.

    Args:
        argv: optional argument list. ``None`` means argparse reads
            ``sys.argv[1:]``, as the original script did.
    """
    args = build_parser().parse_args(argv)
    # Replace the 'default' sentinel with a dataset/model-specific file name.
    args.config = resolve_config_name(args.dataset, args.model, args.config)

    # The config is read from <cwd>/config/, so the script must be launched
    # from the project root.
    config = read_yaml(f"{os.getcwd()}/config/{args.config}")

    # One result directory per run, prefixed with the launch time in Unix
    # seconds; it is attached to args so the experiment can see it.
    save_dir = f"{os.getcwd()}/save/{int(time.time())}_{args.dataset}_{args.model}/"
    args.save_dir = save_dir
    # NOTE(review): get_logger presumably writes 'result.log' under save_dir
    # (save_result=True) -- confirm in utils.
    logger = get_logger(save_result=True, save_dir=save_dir, save_file='result.log')
    logger.info(config)

    # Built here rather than at module level so importing this module does not
    # depend on the experiment classes. argparse `choices` guarantees that
    # args.dataset is one of these keys.
    experiments = {
        'bci2a': BCI2aExperiment,
        'physionet': PhysionetExperiment,
    }
    exp = experiments[args.dataset](args=args, config=config, logger=logger)
    exp.run()


if __name__ == '__main__':
    main()