Merge pull request #1 from samre12/dev
Refactored code for a base cryptotrading env
samre12 authored May 26, 2018
2 parents 2d83de6 + 632dc63 · commit 14850a0
Showing 6 changed files with 132 additions and 105 deletions.
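
This commit splits the old all-in-one BaseEnv in two: basicenv.py now holds a pure abstract base class whose hooks all raise NotImplementedError, while the new cryptoenv.py takes over the concrete gym.Env plumbing (generator loading, random-episode selection, set_params), and the three PnL environments inherit from CryptoEnv instead of BaseEnv. The resulting hierarchy, abbreviated:

```
BaseEnv                            (basicenv.py, abstract contract only)
└── CryptoEnv(gym.Env, BaseEnv)    (cryptoenv.py, data and episode logic)
    ├── RealizedPnLEnv             (realizedPnL.py)
    ├── UnRealizedPnLEnv           (unrealizedPnL.py)
    └── WeightedPnLEnv             (weightedPnL.py)
```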
1 change: 0 additions & 1 deletion gym_cryptotrading/__init__.py
```diff
@@ -1,4 +1,3 @@
-import logging
 from gym.envs.registration import register
 
 register(
```
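The arguments of the register(...) call are collapsed in this view. For orientation, Gym registration pairs an environment id with an entry point; a minimal sketch of such a call (the id and entry point below are illustrative placeholders, not necessarily what gym_cryptotrading registers):

```python
from gym.envs.registration import register

# Illustrative placeholders only: the actual arguments of the
# register(...) call are collapsed in the diff above.
register(
    id='RealizedPnLEnv-v0',                               # hypothetical id
    entry_point='gym_cryptotrading.envs:RealizedPnLEnv',  # hypothetical entry point
)
```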
109 changes: 23 additions & 86 deletions gym_cryptotrading/envs/basicenv.py
```diff
@@ -1,114 +1,51 @@
-import numpy as np
+from abc import ABCMeta, abstractmethod
 
-import gym
-from gym import error, logger
+from gym import logger
 
-from abc import abstractmethod
-
-from gym_cryptotrading.generator import Generator
-from gym_cryptotrading.strings import *
-from gym_cryptotrading.errors import *
-
-from gym_cryptotrading.spaces.action import ActionSpace
-from gym_cryptotrading.spaces.observation import ObservationSpace
-
-class BaseEnv(gym.Env):
-    action_space = ActionSpace()
-    observation_space = ObservationSpace()
-    metadata = {'render.modes': []}
+class BaseEnv:
+    '''
+    Abstract Base Class for CryptoTrading Environments
+    '''
+    __metaclass__ = ABCMeta
 
     def __init__(self):
         self.episode_number = 0
-        self.generator = None
         self.logger = logger
 
         self.history_length = 100
         self.horizon = 5
         self.unit = 5e-4
 
     @abstractmethod
-    def _set_env_specific_params(self, **kwargs):
-        pass
+    def _get_new_state(self):
+        raise NotImplementedError
 
-    def _load_gen(self):
-        if not self.generator:
-            self.generator = Generator(self.history_length, self.horizon)
+    @abstractmethod
+    def _get_reward(self):
+        raise NotImplementedError
 
+    @abstractmethod
     def _new_random_episode(self):
-        '''
-        TODO: In the current setting, the selection of an episode does not follow pure uniform process.
-        Need to index every episode and then generate a random index rather than going on multiple levels
-        of selection.
-        '''
-        self._load_gen()
-        self._reset_params()
-        message_list = []
-        self.episode_number = self.episode_number + 1
-        message_list.append("Starting a new episode numbered {}".format(self.episode_number))
-
-        block_index = np.random.randint(0, len(self.generator.price_blocks) - 1)
-        message_list.append("Block index selected for episode number {} is {}".format(
-            self.episode_number, block_index
-            )
-        )
-
-        self.diffs = self.generator.diff_blocks[block_index]
-        self.historical_prices = self.generator.price_blocks[block_index]
-
-        self.current = np.random.randint(self.history_length,
-                                        len(self.historical_prices) - self.horizon)
-        message_list.append(
-            "Starting index and timestamp point selected for episode number {} is {}:==:{}".format(
-                self.episode_number,
-                self.current,
-                self.generator.timestamp_blocks[block_index][self.current]
-            )
-        )
-
-        map(self.logger.debug, message_list)
-
-        return self.historical_prices[self.current - self.history_length:self.current]
+        raise NotImplementedError
 
     @abstractmethod
     def _reset_params(self):
-        pass
+        raise NotImplementedError
 
     @abstractmethod
-    def _take_action(self, action):
-        pass
+    def _set_env_specific_params(self, **kwargs):
+        raise NotImplementedError
 
     @abstractmethod
-    def _get_reward(self):
-        return 0
+    def _take_action(self, action):
+        raise NotImplementedError
 
-    def _get_new_state(self):
-        return self.historical_prices[self.current]
+    @abstractmethod
+    def set_params(self, history_length, horizon, unit, **kwargs):
+        raise NotImplementedError
 
     def set_logger(self, custom_logger):
        if custom_logger:
            self.logger = custom_logger
 
-    def set_params(self, history_length, horizon, unit, **kwargs):
-        if self.generator:
-            raise EnvironmentAlreadyLoaded()
-
-        if history_length < 0 or horizon < 1 or unit < 0:
-            raise ValueError()
-
-        else:
-            self.history_length = history_length
-            self.horizon = horizon
-            self.unit = unit #units of Bitcoin traded each time
-
-        self._set_env_specific_params(**kwargs)
-
-    def reset(self):
-        return self._new_random_episode()
-
+    @abstractmethod
     def step(self, action):
-        state = self._get_new_state()
-        self._take_action(action)
-        reward = self._get_reward()
-        return state, reward, False, None
+        raise NotImplementedError
```
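
With every hook now raising NotImplementedError, BaseEnv is a pure contract: a subclass has to supply state, reward, episode, parameter, and action logic itself. A minimal sketch of a conforming subclass (the class and its constant return values are illustrative, not part of this repository):

```python
from gym_cryptotrading.envs.basicenv import BaseEnv

class ConstantEnv(BaseEnv):
    '''Hypothetical subclass: stubs every abstract hook of BaseEnv.'''

    def _get_new_state(self):
        return 0.0                  # placeholder state

    def _get_reward(self):
        return 0.0                  # placeholder reward

    def _new_random_episode(self):
        self._reset_params()
        return []                   # placeholder initial observation

    def _reset_params(self):
        self.timesteps = 0

    def _set_env_specific_params(self, **kwargs):
        pass

    def _take_action(self, action):
        pass

    def set_params(self, history_length, horizon, unit, **kwargs):
        self.history_length = history_length
        self.horizon = horizon
        self.unit = unit
        self._set_env_specific_params(**kwargs)

    def step(self, action):
        self._take_action(action)
        return self._get_new_state(), self._get_reward(), False, None
```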

91 changes: 91 additions & 0 deletions gym_cryptotrading/envs/cryptoenv.py
```diff
@@ -0,0 +1,91 @@
+import numpy as np
+
+import gym
+from gym import error, logger
+
+from abc import abstractmethod
+
+from gym_cryptotrading.envs.basicenv import BaseEnv
+
+from gym_cryptotrading.generator import Generator
+from gym_cryptotrading.strings import *
+from gym_cryptotrading.errors import *
+
+from gym_cryptotrading.spaces.action import ActionSpace
+from gym_cryptotrading.spaces.observation import ObservationSpace
+
+class CryptoEnv(gym.Env, BaseEnv):
+    action_space = ActionSpace()
+    observation_space = ObservationSpace()
+    metadata = {'render.modes': []}
+
+    def __init__(self):
+        super(CryptoEnv, self).__init__()
+        self.generator = None
+
+    def _get_new_state(self):
+        return self.historical_prices[self.current]
+
+    def _load_gen(self):
+        if not self.generator:
+            self.generator = Generator(self.history_length, self.horizon)
+
+    def _new_random_episode(self):
+        '''
+        TODO: In the current setting, the selection of an episode does not follow pure uniform process.
+        Need to index every episode and then generate a random index rather than going on multiple levels
+        of selection.
+        '''
+        self._load_gen()
+        self._reset_params()
+        message_list = []
+        self.episode_number = self.episode_number + 1
+        message_list.append("Starting a new episode numbered {}".format(self.episode_number))
+
+        block_index = np.random.randint(0, len(self.generator.price_blocks) - 1)
+        message_list.append("Block index selected for episode number {} is {}".format(
+            self.episode_number, block_index
+            )
+        )
+
+        self.diffs = self.generator.diff_blocks[block_index]
+        self.historical_prices = self.generator.price_blocks[block_index]
+
+        self.current = np.random.randint(self.history_length,
+                                        len(self.historical_prices) - self.horizon)
+        message_list.append(
+            "Starting index and timestamp point selected for episode number {} is {}:==:{}".format(
+                self.episode_number,
+                self.current,
+                self.generator.timestamp_blocks[block_index][self.current]
+            )
+        )
+
+        map(self.logger.debug, message_list)
+
+        return self.historical_prices[self.current - self.history_length:self.current]
+
+
+    def _reset_params(self):
+        pass
+
+    def _set_env_specific_params(self, **kwargs):
+        pass
+
+    def reset(self):
+        return self._new_random_episode()
+
+    def set_params(self, history_length, horizon, unit, **kwargs):
+        if self.generator:
+            raise EnvironmentAlreadyLoaded()
+
+        if history_length < 0 or horizon < 1 or unit < 0:
+            raise ValueError()
+
+        else:
+            self.history_length = history_length
+            self.horizon = horizon
+            self.unit = unit #units of Bitcoin traded each time
+
+        self._set_env_specific_params(**kwargs)
```
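
In use, a concrete environment built on CryptoEnv follows the usual Gym loop: set_params, if called at all, must run before the generator is first loaded (EnvironmentAlreadyLoaded guards against re-parameterizing it), reset returns the trailing history_length prices of a randomly chosen episode, and step returns a (state, reward, done, info) tuple. A hedged usage sketch with one of the concrete subclasses:

```python
from gym_cryptotrading.envs.unrealizedPnL import UnRealizedPnLEnv

env = UnRealizedPnLEnv()
env.set_params(history_length=100, horizon=5, unit=5e-4)  # before the generator loads

state = env.reset()    # trailing `history_length` prices of a random episode
for _ in range(10):
    # assumes the custom ActionSpace exposes sample(), as gym spaces usually do
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
```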

12 changes: 6 additions & 6 deletions gym_cryptotrading/envs/realizedPnL.py
```diff
@@ -3,9 +3,9 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
-class RealizedPnLEnv(BaseEnv):
+class RealizedPnLEnv(CryptoEnv):
     def __init__(self):
         super(RealizedPnLEnv, self).__init__()
 
@@ -15,13 +15,13 @@ def _reset_params(self):
         self.reward = 0.0
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -41,7 +41,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
         )
         self.logger.debug(message)
```
12 changes: 6 additions & 6 deletions gym_cryptotrading/envs/unrealizedPnL.py
```diff
@@ -3,9 +3,9 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
-class UnRealizedPnLEnv(BaseEnv):
+class UnRealizedPnLEnv(CryptoEnv):
     def __init__(self):
         super(UnRealizedPnLEnv, self).__init__()
 
@@ -14,13 +14,13 @@ def _reset_params(self):
         self.timesteps = 0
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -35,7 +35,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
         )
         self.logger.debug(message)
```

12 changes: 6 additions & 6 deletions gym_cryptotrading/envs/weightedPnL.py
```diff
@@ -5,7 +5,7 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
 class ExponentiallyWeightedReward:
     def __init__(self, lag, decay_rate):
@@ -36,7 +36,7 @@ def insert(self, reward):
     def reward(self):
         return self.sum / self.denominator
 
-class WeightedPnLEnv(BaseEnv):
+class WeightedPnLEnv(CryptoEnv):
     def __init__(self):
         super(WeightedPnLEnv, self).__init__()
 
@@ -63,13 +63,13 @@ def _reset_params(self):
         self.reward = ExponentiallyWeightedReward(self.lag, self.decay_rate)
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -86,7 +86,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
        )
         self.logger.debug(message)
```
