-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
47 lines (37 loc) · 1.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tensorflow as tf
import os
import argparse
from game_runner import GameRunner
def main():
parser = argparse.ArgumentParser(description='Deep Q Learning')
parser.add_argument(
'-m',
'--model-dir',
required=False,
default=None,
help='Path to a new or existing model directory.'
)
parser.add_argument(
'-t',
'--train-model',
required=False,
default=True,
action='store_true',
help='If true, model will be trained. Otherwise, it will be evaluated.'
)
parser.add_argument(
'-d',
'--default-config',
required=False,
default=os.path.join('default_configs', 'CartPole-v0.yaml'),
help='The default config to use when creating a new model.'
)
args = parser.parse_args()
with tf.Session() as session:
game_runner = GameRunner(session, args.default_config, args.model_dir)
if args.train_model:
game_runner.train()
else:
game_runner.evaluation()
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()