diff --git a/gym_cryptotrading/__init__.py b/gym_cryptotrading/__init__.py
index cc92b3b..31d77cf 100644
--- a/gym_cryptotrading/__init__.py
+++ b/gym_cryptotrading/__init__.py
@@ -1,4 +1,3 @@
-import logging
 from gym.envs.registration import register
 
 register(
diff --git a/gym_cryptotrading/envs/basicenv.py b/gym_cryptotrading/envs/basicenv.py
index 452a022..90fbbd6 100644
--- a/gym_cryptotrading/envs/basicenv.py
+++ b/gym_cryptotrading/envs/basicenv.py
@@ -1,25 +1,16 @@
-import numpy as np
+from abc import ABCMeta, abstractmethod
 
-import gym
-from gym import error, logger
+from gym import logger
 
-from abc import abstractmethod
+class BaseEnv:
+    '''
+    Abstract Base Class for CryptoTrading Environments
+    '''
 
-from gym_cryptotrading.generator import Generator
-from gym_cryptotrading.strings import *
-from gym_cryptotrading.errors import *
-
-from gym_cryptotrading.spaces.action import ActionSpace
-from gym_cryptotrading.spaces.observation import ObservationSpace
-
-class BaseEnv(gym.Env):
-    action_space = ActionSpace()
-    observation_space = ObservationSpace()
-    metadata = {'render.modes': []}
+    __metaclass__ = ABCMeta
 
     def __init__(self):
         self.episode_number = 0
-        self.generator = None
         self.logger = logger
 
         self.history_length = 100
@@ -27,88 +18,34 @@ def __init__(self):
         self.unit = 5e-4
 
     @abstractmethod
-    def _set_env_specific_params(self, **kwargs):
-        pass
+    def _get_new_state(self):
+        raise NotImplementedError
 
-    def _load_gen(self):
-        if not self.generator:
-            self.generator = Generator(self.history_length, self.horizon)
+    @abstractmethod
+    def _get_reward(self):
+        raise NotImplementedError
 
+    @abstractmethod
     def _new_random_episode(self):
-        '''
-        TODO: In the current setting, the selection of an episode does not follow pure uniform process.
-        Need to index every episode and then generate a random index rather than going on multiple levels
-        of selection.
-        '''
-        self._load_gen()
-        self._reset_params()
-        message_list = []
-        self.episode_number = self.episode_number + 1
-        message_list.append("Starting a new episode numbered {}".format(self.episode_number))
-
-        block_index = np.random.randint(0, len(self.generator.price_blocks) - 1)
-        message_list.append("Block index selected for episode number {} is {}".format(
-                self.episode_number, block_index
-            )
-        )
-
-        self.diffs = self.generator.diff_blocks[block_index]
-        self.historical_prices = self.generator.price_blocks[block_index]
-
-        self.current = np.random.randint(self.history_length,
-                                        len(self.historical_prices) - self.horizon)
-        message_list.append(
-            "Starting index and timestamp point selected for episode number {} is {}:==:{}".format(
-                self.episode_number,
-                self.current,
-                self.generator.timestamp_blocks[block_index][self.current]
-            )
-        )
-
-        map(self.logger.debug, message_list)
-
-        return self.historical_prices[self.current - self.history_length:self.current]
+        raise NotImplementedError
 
     @abstractmethod
     def _reset_params(self):
-        pass
+        raise NotImplementedError
 
     @abstractmethod
-    def _take_action(self, action):
-        pass
-    
+    def _set_env_specific_params(self, **kwargs):
+        raise NotImplementedError
+
     @abstractmethod
-    def _get_reward(self):
-        return 0
+    def _take_action(self, action):
+        raise NotImplementedError
 
-    def _get_new_state(self):
-        return self.historical_prices[self.current]
+    @abstractmethod
+    def set_params(self, history_length, horizon, unit, **kwargs):
+        raise NotImplementedError
 
     def set_logger(self, custom_logger):
         if custom_logger:
             self.logger = custom_logger
-
-    def set_params(self, history_length, horizon, unit, **kwargs):
-        if self.generator:
-            raise EnvironmentAlreadyLoaded()
-
-        if history_length < 0 or horizon < 1 or unit < 0:
-            raise ValueError()
-
-        else:
-            self.history_length = history_length
-            self.horizon = horizon
-            self.unit = unit #units of Bitcoin traded each time
-
-            self._set_env_specific_params(**kwargs)
-
-    def reset(self):
-        return self._new_random_episode()
-
-    @abstractmethod
-    def step(self, action):
-        state = self._get_new_state()
-        self._take_action(action)
-        reward = self._get_reward()
-        return state, reward, False, None
\ No newline at end of file
diff --git a/gym_cryptotrading/envs/cryptoenv.py b/gym_cryptotrading/envs/cryptoenv.py
new file mode 100644
index 0000000..aae9172
--- /dev/null
+++ b/gym_cryptotrading/envs/cryptoenv.py
@@ -0,0 +1,91 @@
+import numpy as np
+
+import gym
+from gym import error, logger
+
+from abc import abstractmethod
+
+from gym_cryptotrading.envs.basicenv import BaseEnv
+
+from gym_cryptotrading.generator import Generator
+from gym_cryptotrading.strings import *
+from gym_cryptotrading.errors import *
+
+from gym_cryptotrading.spaces.action import ActionSpace
+from gym_cryptotrading.spaces.observation import ObservationSpace
+
+class CryptoEnv(gym.Env, BaseEnv):
+    action_space = ActionSpace()
+    observation_space = ObservationSpace()
+    metadata = {'render.modes': []}
+
+    def __init__(self):
+        super(CryptoEnv, self).__init__()
+        self.generator = None
+
+    def _get_new_state(self):
+        return self.historical_prices[self.current]
+
+    def _load_gen(self):
+        if not self.generator:
+            self.generator = Generator(self.history_length, self.horizon)
+
+    def _new_random_episode(self):
+        '''
+        TODO: In the current setting, the selection of an episode does not follow pure uniform process.
+        Need to index every episode and then generate a random index rather than going on multiple levels
+        of selection.
+        '''
+        self._load_gen()
+        self._reset_params()
+        message_list = []
+        self.episode_number = self.episode_number + 1
+        message_list.append("Starting a new episode numbered {}".format(self.episode_number))
+
+        block_index = np.random.randint(0, len(self.generator.price_blocks) - 1)
+        message_list.append("Block index selected for episode number {} is {}".format(
+                self.episode_number, block_index
+            )
+        )
+
+        self.diffs = self.generator.diff_blocks[block_index]
+        self.historical_prices = self.generator.price_blocks[block_index]
+
+        self.current = np.random.randint(self.history_length,
+                                        len(self.historical_prices) - self.horizon)
+        message_list.append(
+            "Starting index and timestamp point selected for episode number {} is {}:==:{}".format(
+                self.episode_number,
+                self.current,
+                self.generator.timestamp_blocks[block_index][self.current]
+            )
+        )
+
+        map(self.logger.debug, message_list)
+
+        return self.historical_prices[self.current - self.history_length:self.current]
+
+
+    def _reset_params(self):
+        pass
+
+    def _set_env_specific_params(self, **kwargs):
+        pass
+
+    def reset(self):
+        return self._new_random_episode()
+
+    def set_params(self, history_length, horizon, unit, **kwargs):
+        if self.generator:
+            raise EnvironmentAlreadyLoaded()
+
+        if history_length < 0 or horizon < 1 or unit < 0:
+            raise ValueError()
+
+        else:
+            self.history_length = history_length
+            self.horizon = horizon
+            self.unit = unit #units of Bitcoin traded each time
+
+            self._set_env_specific_params(**kwargs)
+        
\ No newline at end of file
diff --git a/gym_cryptotrading/envs/realizedPnL.py b/gym_cryptotrading/envs/realizedPnL.py
index f4b2b1c..6985d2f 100644
--- a/gym_cryptotrading/envs/realizedPnL.py
+++ b/gym_cryptotrading/envs/realizedPnL.py
@@ -3,9 +3,9 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
-class RealizedPnLEnv(BaseEnv):
+class RealizedPnLEnv(CryptoEnv):
     def __init__(self):
         super(RealizedPnLEnv, self).__init__()
 
@@ -15,13 +15,13 @@ def _reset_params(self):
         self.reward = 0.0
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
 
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -41,7 +41,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
         )
         self.logger.debug(message)
diff --git a/gym_cryptotrading/envs/unrealizedPnL.py b/gym_cryptotrading/envs/unrealizedPnL.py
index 5f43c09..7d7393c 100644
--- a/gym_cryptotrading/envs/unrealizedPnL.py
+++ b/gym_cryptotrading/envs/unrealizedPnL.py
@@ -3,9 +3,9 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
-class UnRealizedPnLEnv(BaseEnv):
+class UnRealizedPnLEnv(CryptoEnv):
     def __init__(self):
         super(UnRealizedPnLEnv, self).__init__()
 
@@ -14,13 +14,13 @@ def _reset_params(self):
         self.timesteps = 0
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
 
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -35,7 +35,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
         )
         self.logger.debug(message)
diff --git a/gym_cryptotrading/envs/weightedPnL.py b/gym_cryptotrading/envs/weightedPnL.py
index 5e04af2..be8f987 100644
--- a/gym_cryptotrading/envs/weightedPnL.py
+++ b/gym_cryptotrading/envs/weightedPnL.py
@@ -5,7 +5,7 @@
 from gym import error
 
 from gym_cryptotrading.strings import *
-from gym_cryptotrading.envs.basicenv import BaseEnv
+from gym_cryptotrading.envs.cryptoenv import CryptoEnv
 
 class ExponentiallyWeightedReward:
     def __init__(self, lag, decay_rate):
@@ -36,7 +36,7 @@ def insert(self, reward):
     def reward(self):
         return self.sum / self.denominator
 
-class WeightedPnLEnv(BaseEnv):
+class WeightedPnLEnv(CryptoEnv):
     def __init__(self):
         super(WeightedPnLEnv, self).__init__()
 
@@ -63,13 +63,13 @@ def _reset_params(self):
         self.reward = ExponentiallyWeightedReward(self.lag, self.decay_rate)
 
     def _take_action(self, action):
-        if action not in BaseEnv.action_space.lookup.keys():
+        if action not in CryptoEnv.action_space.lookup.keys():
             raise error.InvalidAction()
 
         else:
-            if BaseEnv.action_space.lookup[action] is LONG:
+            if CryptoEnv.action_space.lookup[action] is LONG:
                 self.long = self.long + 1
 
-            elif BaseEnv.action_space.lookup[action] is SHORT:
+            elif CryptoEnv.action_space.lookup[action] is SHORT:
                 self.short = self.short + 1
 
     def _get_reward(self):
@@ -86,7 +86,7 @@ def step(self, action):
         reward = self._get_reward()
 
         message = "Timestep {}:==: Action: {} ; Reward: {}".format(
-            self.timesteps, BaseEnv.action_space.lookup[action], reward
+            self.timesteps, CryptoEnv.action_space.lookup[action], reward
        )
         self.logger.debug(message)
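Note on the ABC mechanics in basicenv.py: __metaclass__ = ABCMeta is the Python 2 spelling and is silently ignored on Python 3, where the @abstractmethod markers would consequently not be enforced (similarly, map(self.logger.debug, message_list) in cryptoenv.py only logs eagerly on Python 2, since map is lazy on Python 3). A minimal version-agnostic sketch of the same enforcement, assuming nothing beyond the interface declared in the patch:

from abc import ABCMeta, abstractmethod

# ABCMeta('ABC', (object,), {}) builds a base class whose metaclass is
# ABCMeta under both Python 2 and Python 3.
ABC = ABCMeta('ABC', (object,), {})

class BaseEnv(ABC):
    @abstractmethod
    def _get_reward(self):
        raise NotImplementedError

# Instantiating a subclass that leaves _get_reward unimplemented now
# raises TypeError on either interpreter.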
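For reference, the shape of a concrete environment under the new hierarchy: subclass CryptoEnv (which supplies the generator plumbing, _get_new_state, reset, and set_params) and fill in the hooks that BaseEnv declares abstract. A minimal sketch; the class name and the reward formula are illustrative assumptions rather than part of this patch, and step mirrors the template removed from basicenv.py:

from gym import error

from gym_cryptotrading.strings import *
from gym_cryptotrading.envs.cryptoenv import CryptoEnv

class ExamplePnLEnv(CryptoEnv):  # hypothetical subclass, not in the repo
    def _reset_params(self):
        self.long = 0
        self.short = 0
        self.timesteps = 0

    def _take_action(self, action):
        if action not in CryptoEnv.action_space.lookup.keys():
            raise error.InvalidAction()
        if CryptoEnv.action_space.lookup[action] is LONG:
            self.long = self.long + 1
        elif CryptoEnv.action_space.lookup[action] is SHORT:
            self.short = self.short + 1

    def _get_reward(self):
        # Illustrative reward (an assumption): net position, scaled by the
        # units traded per step and the latest price difference.
        return (self.long - self.short) * self.unit * self.diffs[self.current]

    def step(self, action):
        state = self._get_new_state()
        self._take_action(action)
        reward = self._get_reward()
        return state, reward, False, None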
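End-to-end usage is unchanged by the refactor. A minimal sketch, assuming 'RealizedPnLEnv-v0' as the registered ID (the register( call in gym_cryptotrading/__init__.py is truncated in the hunk above, so the real ID may differ):

import gym
import gym_cryptotrading  # importing the package runs register() in __init__.py

env = gym.make('RealizedPnLEnv-v0')  # assumed ID; check __init__.py

# set_params must precede the first reset(): once reset() has loaded the
# Generator, a later set_params call raises EnvironmentAlreadyLoaded.
# history_length and unit below match BaseEnv's visible defaults;
# horizon=5 is an assumption. unwrapped reaches the CryptoEnv instance
# even if gym.make added a wrapper in this gym version.
env.unwrapped.set_params(history_length=100, horizon=5, unit=5e-4)

state = env.reset()                   # the last history_length prices
state, reward, done, _ = env.step(0)  # an index into action_space.lookup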