-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from samre12/dev
Refactored code for a base cryptotrading env
- Loading branch information
Showing
6 changed files
with
132 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
import logging | ||
from gym.envs.registration import register | ||
|
||
register( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,114 +1,51 @@ | ||
import numpy as np | ||
from abc import ABCMeta, abstractmethod | ||
|
||
import gym | ||
from gym import error, logger | ||
from gym import logger | ||
|
||
from abc import abstractmethod | ||
class BaseEnv: | ||
''' | ||
Abstract Base Class for CryptoTrading Environments | ||
''' | ||
|
||
from gym_cryptotrading.generator import Generator | ||
from gym_cryptotrading.strings import * | ||
from gym_cryptotrading.errors import * | ||
|
||
from gym_cryptotrading.spaces.action import ActionSpace | ||
from gym_cryptotrading.spaces.observation import ObservationSpace | ||
|
||
class BaseEnv(gym.Env): | ||
action_space = ActionSpace() | ||
observation_space = ObservationSpace() | ||
metadata = {'render.modes': []} | ||
__metaclass__ = ABCMeta | ||
|
||
def __init__(self): | ||
self.episode_number = 0 | ||
self.generator = None | ||
self.logger = logger | ||
|
||
self.history_length = 100 | ||
self.horizon = 5 | ||
self.unit = 5e-4 | ||
|
||
@abstractmethod | ||
def _set_env_specific_params(self, **kwargs): | ||
pass | ||
def _get_new_state(self): | ||
raise NotImplementedError | ||
|
||
def _load_gen(self): | ||
if not self.generator: | ||
self.generator = Generator(self.history_length, self.horizon) | ||
@abstractmethod | ||
def _get_reward(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def _new_random_episode(self): | ||
''' | ||
TODO: In the current setting, the selection of an episode does not follow pure uniform process. | ||
Need to index every episode and then generate a random index rather than going on multiple levels | ||
of selection. | ||
''' | ||
self._load_gen() | ||
self._reset_params() | ||
message_list = [] | ||
self.episode_number = self.episode_number + 1 | ||
message_list.append("Starting a new episode numbered {}".format(self.episode_number)) | ||
|
||
block_index = np.random.randint(0, len(self.generator.price_blocks) - 1) | ||
message_list.append("Block index selected for episode number {} is {}".format( | ||
self.episode_number, block_index | ||
) | ||
) | ||
|
||
self.diffs = self.generator.diff_blocks[block_index] | ||
self.historical_prices = self.generator.price_blocks[block_index] | ||
|
||
self.current = np.random.randint(self.history_length, | ||
len(self.historical_prices) - self.horizon) | ||
message_list.append( | ||
"Starting index and timestamp point selected for episode number {} is {}:==:{}".format( | ||
self.episode_number, | ||
self.current, | ||
self.generator.timestamp_blocks[block_index][self.current] | ||
) | ||
) | ||
|
||
map(self.logger.debug, message_list) | ||
|
||
return self.historical_prices[self.current - self.history_length:self.current] | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def _reset_params(self): | ||
pass | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def _take_action(self, action): | ||
pass | ||
def _set_env_specific_params(self, **kwargs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def _get_reward(self): | ||
return 0 | ||
def _take_action(self, action): | ||
raise NotImplementedError | ||
|
||
def _get_new_state(self): | ||
return self.historical_prices[self.current] | ||
@abstractmethod | ||
def set_params(self, history_length, horizon, unit, **kwargs): | ||
raise NotImplementedError | ||
|
||
def set_logger(self, custom_logger): | ||
if custom_logger: | ||
self.logger = custom_logger | ||
|
||
def set_params(self, history_length, horizon, unit, **kwargs): | ||
if self.generator: | ||
raise EnvironmentAlreadyLoaded() | ||
|
||
if history_length < 0 or horizon < 1 or unit < 0: | ||
raise ValueError() | ||
|
||
else: | ||
self.history_length = history_length | ||
self.horizon = horizon | ||
self.unit = unit #units of Bitcoin traded each time | ||
|
||
self._set_env_specific_params(**kwargs) | ||
|
||
def reset(self): | ||
return self._new_random_episode() | ||
|
||
@abstractmethod | ||
def step(self, action): | ||
state = self._get_new_state() | ||
self._take_action(action) | ||
reward = self._get_reward() | ||
return state, reward, False, None | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
|
||
import gym | ||
from gym import error, logger | ||
|
||
from abc import abstractmethod | ||
|
||
from gym_cryptotrading.envs.basicenv import BaseEnv | ||
|
||
from gym_cryptotrading.generator import Generator | ||
from gym_cryptotrading.strings import * | ||
from gym_cryptotrading.errors import * | ||
|
||
from gym_cryptotrading.spaces.action import ActionSpace | ||
from gym_cryptotrading.spaces.observation import ObservationSpace | ||
|
||
class CryptoEnv(gym.Env, BaseEnv):
    '''
    Gym environment that serves episodes drawn at random from the price
    blocks held by a `Generator`.

    The generator is built lazily on the first `reset()`. Until then,
    `set_params()` may be used to change `history_length`, `horizon` and
    `unit`; afterwards it raises `EnvironmentAlreadyLoaded`.

    NOTE(review): `history_length`, `horizon`, `unit`, `episode_number` and
    `logger` are not assigned here; they are expected to be provided by
    `BaseEnv.__init__` - confirm against the base class.
    '''
    action_space = ActionSpace()
    observation_space = ObservationSpace()
    metadata = {'render.modes': []}

    def __init__(self):
        super(CryptoEnv, self).__init__()
        # Created on demand by _load_gen(); its presence is what locks
        # set_params().
        self.generator = None

    def _get_new_state(self):
        '''Return the entry of the current price block at index `self.current`.'''
        return self.historical_prices[self.current]

    def _load_gen(self):
        '''Build the `Generator` from the current parameters if none exists yet.'''
        if not self.generator:
            self.generator = Generator(self.history_length, self.horizon)

    def _new_random_episode(self):
        '''
        Start a new episode at a random position of a random block.

        Returns the `history_length` prices that immediately precede the
        chosen starting index.

        TODO: In the current setting, the selection of an episode does not
        follow pure uniform process. Need to index every episode and then
        generate a random index rather than going on multiple levels of
        selection.
        '''
        self._load_gen()
        self._reset_params()

        self.episode_number = self.episode_number + 1
        message_list = [
            "Starting a new episode numbered {}".format(self.episode_number)
        ]

        # The upper bound of np.random.randint is exclusive, so passing the
        # number of blocks makes every block (including the last one)
        # selectable and also works when there is exactly one block.
        block_index = np.random.randint(0, len(self.generator.price_blocks))
        message_list.append(
            "Block index selected for episode number {} is {}".format(
                self.episode_number, block_index
            )
        )

        self.diffs = self.generator.diff_blocks[block_index]
        self.historical_prices = self.generator.price_blocks[block_index]

        # Leave `history_length` entries before the start for the initial
        # observation and keep the start at least `horizon` entries away
        # from the end of the block.
        self.current = np.random.randint(
            self.history_length,
            len(self.historical_prices) - self.horizon
        )
        message_list.append(
            "Starting index and timestamp point selected for episode number {} is {}:==:{}".format(
                self.episode_number,
                self.current,
                self.generator.timestamp_blocks[block_index][self.current]
            )
        )

        # Explicit loop: on Python 3 `map` is lazy, so the previous
        # `map(self.logger.debug, ...)` never emitted anything.
        for message in message_list:
            self.logger.debug(message)

        return self.historical_prices[self.current - self.history_length:self.current]

    def _reset_params(self):
        '''Hook called at the start of every episode; no-op in this class.'''
        pass

    def _set_env_specific_params(self, **kwargs):
        '''Hook receiving the extra keyword arguments of `set_params`; no-op here.'''
        pass

    def reset(self):
        '''Begin a new random episode and return its initial price history.'''
        return self._new_random_episode()

    def set_params(self, history_length, horizon, unit, **kwargs):
        '''
        Configure the environment before the generator has been loaded.

        Raises `EnvironmentAlreadyLoaded` if the generator already exists and
        `ValueError` if `history_length < 0`, `horizon < 1` or `unit < 0`.
        Remaining keyword arguments are forwarded to
        `_set_env_specific_params`.
        '''
        if self.generator:
            raise EnvironmentAlreadyLoaded()

        if history_length < 0 or horizon < 1 or unit < 0:
            raise ValueError(
                "history_length must be >= 0, horizon must be >= 1 and "
                "unit must be >= 0"
            )

        self.history_length = history_length
        self.horizon = horizon
        self.unit = unit  # units of Bitcoin traded each time

        self._set_env_specific_params(**kwargs)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters