forked from maschulz/deeperbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
51 lines (43 loc) · 2.02 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import logging
import os
import warnings
from sklearn.exceptions import DataConversionWarning, ConvergenceWarning
from lib.config import Config
from lib.data import DATA
from lib.grid import Grid, GRIDS
from lib.models import MODELS
from lib.preprocessing import TRAFOS, SCALING
from lib.trial import run
# Process-wide: ignore these three warning categories entirely.
# (Filters are installed in the same order as before.)
for _ignored_category in (FutureWarning, DataConversionWarning, ConvergenceWarning):
    warnings.simplefilter(action='ignore', category=_ignored_category)

# Give the root logger its default handler (if it has none yet), then let
# every record from DEBUG upward through. Two calls are kept on purpose:
# basicConfig(level=...) would not change the level when handlers exist.
logging.basicConfig()
_root_logger = logging.getLogger()
_root_logger.setLevel(logging.DEBUG)

# Drop the temporaries so the module namespace is unchanged.
del _ignored_category, _root_logger
def _parse_args(argv=None):
    """Build the CLI parser and parse *argv* (``None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description='Evaluate dataset x model combination.')
    parser.add_argument('--data', default='mnist', type=str, choices=list(DATA.keys()))
    parser.add_argument('--model', default='majority', type=str, choices=list(MODELS.keys()))
    parser.add_argument('--grid', default='v1', type=str, choices=list(GRIDS.keys()))
    parser.add_argument('--trafo', default='identity', type=str, choices=list(TRAFOS.keys()))
    parser.add_argument('--scaling', default='standard', type=str, choices=list(SCALING.keys()))
    # Passed as ``n_components`` to the selected transformation.
    parser.add_argument('--dim', default=784, type=int)
    # A real list, not a bare ``range``: argparse does not run ``type``/``nargs``
    # over non-string defaults, so without this the default would have a
    # different type than a value parsed from the command line.
    parser.add_argument('--seeds', nargs='+', default=list(range(10)), type=int)
    # No default: ``None`` when omitted (presumably Config chooses its own
    # sample sizes then -- TODO confirm).
    parser.add_argument('--sample_sizes', nargs='+', type=int)
    parser.add_argument('--hyperopt', action="store_true", default=False)
    return parser.parse_args(argv)


def main(argv=None):
    """Evaluate one dataset x model combination described by the CLI options.

    Parses the options, makes sure the ``results`` and ``nn_weights`` folders
    exist, assembles a ``Config`` from the selected dataset, scaler,
    transformation, model and grid, and passes it to ``run``.

    Args:
        argv: Optional list of command-line tokens; ``None`` uses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    # Output folders (presumably written to by ``run`` -- verify). ``exist_ok``
    # avoids the check-then-create race when several runs start concurrently.
    for folder in ('results', 'nn_weights'):
        os.makedirs(folder, exist_ok=True)

    config = Config(DATA[args.data],
                    SCALING[args.scaling](),
                    TRAFOS[args.trafo](n_components=args.dim),
                    MODELS[args.model],
                    Grid(args.model, args.grid),
                    seeds=args.seeds,
                    sample_sizes=args.sample_sizes,
                    hyperopt=args.hyperopt
                    )
    run(config)


if __name__ == '__main__':
    main()