From dd9af6ef156e94e2baebdbc879f80a481ed76cd6 Mon Sep 17 00:00:00 2001 From: quantylab Date: Tue, 26 Jan 2021 12:43:28 +0000 Subject: [PATCH] =?UTF-8?q?=EC=9B=90=EC=88=AD=EC=9D=B4=ED=88=AC=EC=9E=90?= =?UTF-8?q?=20=EC=8B=A4=ED=96=89=20=EA=B0=9C=EC=84=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent.py | 6 ++++-- learners.py | 1 + main.py | 15 +++++++++++---- networks.py | 4 +++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/agent.py b/agent.py index 6f0ed15..79b5984 100644 --- a/agent.py +++ b/agent.py @@ -65,8 +65,10 @@ def reset(self): self.ratio_hold = 0 self.ratio_portfolio_value = 0 - def reset_exploration(self): - self.exploration_base = 0.5 + np.random.rand() / 2 + def reset_exploration(self, alpha=None): + if alpha is None: + alpha = np.random.rand() / 2 + self.exploration_base = 0.5 + alpha def set_balance(self, balance): self.initial_balance = balance diff --git a/learners.py b/learners.py index 46cc821..05bda3a 100644 --- a/learners.py +++ b/learners.py @@ -287,6 +287,7 @@ def run( self.agent.reset_exploration() else: epsilon = start_epsilon + self.agent.reset_exploration(alpha=0) while True: # 샘플 생성 diff --git a/main.py b/main.py index c38dc04..6847348 100644 --- a/main.py +++ b/main.py @@ -14,9 +14,9 @@ parser.add_argument('--stock_code', nargs='+') parser.add_argument('--ver', choices=['v1', 'v2'], default='v2') parser.add_argument('--rl_method', - choices=['dqn', 'pg', 'ac', 'a2c', 'a3c']) + choices=['dqn', 'pg', 'ac', 'a2c', 'a3c', 'monkey']) parser.add_argument('--net', - choices=['dnn', 'lstm', 'cnn'], default='dnn') + choices=['dnn', 'lstm', 'cnn', 'monkey'], default='dnn') parser.add_argument('--num_steps', type=int, default=1) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--discount_factor', type=float, default=0.9) @@ -63,8 +63,8 @@ # 로그, Keras Backend 설정을 먼저하고 RLTrader 모듈들을 이후에 임포트해야 함 from agent import Agent - from learners import DQNLearner, PolicyGradientLearner, \ - ActorCriticLearner, A2CLearner, A3CLearner + from learners import ReinforcementLearner, DQNLearner, \ + PolicyGradientLearner, ActorCriticLearner, A2CLearner, A3CLearner # 모델 경로 준비 value_network_path = '' @@ -130,6 +130,13 @@ learner = A2CLearner(**{**common_params, 'value_network_path': value_network_path, 'policy_network_path': policy_network_path}) + elif args.rl_method == 'monkey': + args.net = args.rl_method + args.num_epoches = 1 + args.discount_factor = None + args.start_epsilon = 1 + args.learning = False + learner = ReinforcementLearner(**common_params) if learner is not None: learner.run(balance=args.balance, num_epoches=args.num_epoches, diff --git a/networks.py b/networks.py index 1107945..6bd1808 100644 --- a/networks.py +++ b/networks.py @@ -23,7 +23,9 @@ def set_session(sess): pass from tensorflow.keras.backend import set_session import tensorflow as tf graph = tf.get_default_graph() - sess = tf.compat.v1.Session() + config = tf.compat.v1.ConfigProto() + config.gpu_options.allow_growth = True + sess = tf.compat.v1.Session(config=config) elif os.environ['KERAS_BACKEND'] == 'plaidml.keras.backend': from keras.models import Model from keras.layers import Input, Dense, LSTM, Conv2D, \