Improve monkey-investing execution
quantylab committed Jan 26, 2021
1 parent 7cb2a40 commit dd9af6e
Showing 4 changed files with 19 additions and 7 deletions.
6 changes: 4 additions & 2 deletions agent.py
@@ -65,8 +65,10 @@ def reset(self):
         self.ratio_hold = 0
         self.ratio_portfolio_value = 0
 
-    def reset_exploration(self):
-        self.exploration_base = 0.5 + np.random.rand() / 2
+    def reset_exploration(self, alpha=None):
+        if alpha is None:
+            alpha = np.random.rand() / 2
+        self.exploration_base = 0.5 + alpha
 
     def set_balance(self, balance):
         self.initial_balance = balance
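The new `alpha` parameter makes the exploration base controllable from outside. A minimal sketch of the two call modes (only the method from this diff; the rest of `Agent` is omitted):

```python
import numpy as np

class Agent:
    def reset_exploration(self, alpha=None):
        # No argument: original behavior, exploration_base is drawn
        # uniformly from [0.5, 1.0).
        if alpha is None:
            alpha = np.random.rand() / 2
        self.exploration_base = 0.5 + alpha

agent = Agent()
agent.reset_exploration()         # random base in [0.5, 1.0)
agent.reset_exploration(alpha=0)  # deterministic base of exactly 0.5
print(agent.exploration_base)     # 0.5
```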
1 change: 1 addition & 0 deletions learners.py
@@ -287,6 +287,7 @@ def run(
                 self.agent.reset_exploration()
             else:
                 epsilon = start_epsilon
+                self.agent.reset_exploration(alpha=0)
 
             while True:
                 # generate a sample
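With this change, a run that skips learning (the `else` branch) pins `exploration_base` to 0.5 instead of leaving it at whatever value a previous epoch drew. A hedged sketch of why that matters; the decision logic below is an assumption for illustration, not code from this diff:

```python
import numpy as np

# Hypothetical illustration (NOT the decide_action actually used in this
# repo): assume that during epsilon-greedy exploration, exploration_base
# is the probability that the exploratory action is a buy.
def explore(epsilon, exploration_base):
    if np.random.rand() < epsilon:  # explore this step
        return 'buy' if np.random.rand() < exploration_base else 'sell'
    return 'exploit'                # otherwise consult the networks

# Before this commit, a fixed start_epsilon still left exploration_base
# at a random value in [0.5, 1.0); pinning it to 0.5 via alpha=0 makes
# exploratory buys and sells equally likely, so non-learning runs are
# comparable across invocations.
```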
15 changes: 11 additions & 4 deletions main.py
@@ -14,9 +14,9 @@
 parser.add_argument('--stock_code', nargs='+')
 parser.add_argument('--ver', choices=['v1', 'v2'], default='v2')
 parser.add_argument('--rl_method',
-    choices=['dqn', 'pg', 'ac', 'a2c', 'a3c'])
+    choices=['dqn', 'pg', 'ac', 'a2c', 'a3c', 'monkey'])
 parser.add_argument('--net',
-    choices=['dnn', 'lstm', 'cnn'], default='dnn')
+    choices=['dnn', 'lstm', 'cnn', 'monkey'], default='dnn')
 parser.add_argument('--num_steps', type=int, default=1)
 parser.add_argument('--lr', type=float, default=0.01)
 parser.add_argument('--discount_factor', type=float, default=0.9)
@@ -63,8 +63,8 @@
 
 # Logging and the Keras backend must be configured first, and the RLTrader modules imported afterwards
 from agent import Agent
-from learners import DQNLearner, PolicyGradientLearner, \
-    ActorCriticLearner, A2CLearner, A3CLearner
+from learners import ReinforcementLearner, DQNLearner, \
+    PolicyGradientLearner, ActorCriticLearner, A2CLearner, A3CLearner
 
 # Prepare the model paths
 value_network_path = ''
@@ -130,6 +130,13 @@
     learner = A2CLearner(**{**common_params,
         'value_network_path': value_network_path,
         'policy_network_path': policy_network_path})
+elif args.rl_method == 'monkey':
+    args.net = args.rl_method
+    args.num_epoches = 1
+    args.discount_factor = None
+    args.start_epsilon = 1
+    args.learning = False
+    learner = ReinforcementLearner(**common_params)
 if learner is not None:
     learner.run(balance=args.balance,
         num_epoches=args.num_epoches,
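Together with the learners.py change above, `learning=False` sends the run through the `else` branch, so `epsilon = start_epsilon = 1` (every step explores) while `exploration_base` is pinned at 0.5. Since this branch passes none of the value/policy network paths the other branches do, the plain `ReinforcementLearner` executes one epoch of purely random actions — the monkey baseline. An illustrative invocation, with a placeholder stock code: `python main.py --stock_code 005930 --rl_method monkey`; note that `--net` need not be given, since the branch overrides it with `args.net = args.rl_method`.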
4 changes: 3 additions & 1 deletion networks.py
@@ -23,7 +23,9 @@ def set_session(sess): pass
     from tensorflow.keras.backend import set_session
     import tensorflow as tf
     graph = tf.get_default_graph()
-    sess = tf.compat.v1.Session()
+    config = tf.compat.v1.ConfigProto()
+    config.gpu_options.allow_growth = True
+    sess = tf.compat.v1.Session(config=config)
 elif os.environ['KERAS_BACKEND'] == 'plaidml.keras.backend':
     from keras.models import Model
     from keras.layers import Input, Dense, LSTM, Conv2D, \
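`allow_growth` stops a TF1-style session from reserving the whole GPU at creation time, which matters when several learner processes (e.g. A3C workers) share one device. A standalone sketch of the same pattern, assuming the `tf.compat.v1` API used in the diff:

```python
import tensorflow as tf
from tensorflow.keras.backend import set_session

# Default TF1 behavior: a new session maps (nearly) all GPU memory up
# front. allow_growth makes allocation incremental, so multiple
# processes can coexist on one GPU instead of the first claiming it all.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

# Hand the session to Keras so subsequent model building/training uses it.
set_session(sess)
```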
