From b65879305af57aa6a3138f6d19b2c43616c87ddf Mon Sep 17 00:00:00 2001 From: DONNOT Benjamin Date: Fri, 8 Nov 2024 09:05:04 +0100 Subject: [PATCH] improve reading speed of FromEpisodeData by not reading everything, see issue #659 Signed-off-by: DONNOT Benjamin --- grid2op/Chronics/fromOneEpisodeData.py | 5 +- grid2op/Episode/EpisodeData.py | 142 ++++++++++++++++--------- grid2op/tests/test_env_from_episode.py | 4 +- 3 files changed, 93 insertions(+), 58 deletions(-) diff --git a/grid2op/Chronics/fromOneEpisodeData.py b/grid2op/Chronics/fromOneEpisodeData.py index 9dbe959e..bd6c85b2 100644 --- a/grid2op/Chronics/fromOneEpisodeData.py +++ b/grid2op/Chronics/fromOneEpisodeData.py @@ -177,12 +177,11 @@ def __init__( if self.path is not None: # logger: this has no impact pass - if isinstance(ep_data, EpisodeData): self._episode_data = ep_data elif isinstance(ep_data, (str, Path)): try: - self._episode_data = EpisodeData.from_disk(*os.path.split(ep_data)) + self._episode_data = EpisodeData.from_disk(*os.path.split(ep_data), _only_act_obs=True) except Exception as exc_: raise ChronicsError("Impossible to build the FromOneEpisodeData with the `ep_data` provided.") from exc_ elif isinstance(ep_data, (tuple, list)): @@ -190,7 +189,7 @@ def __init__( raise ChronicsError("When you provide a tuple, or a list, FromOneEpisodeData can only be used if this list has length 2. " f"Length {len(ep_data)} found.") try: - self._episode_data = EpisodeData.from_disk(*ep_data) + self._episode_data = EpisodeData.from_disk(*ep_data, _only_act_obs=True) except Exception as exc_: raise ChronicsError("Impossible to build the FromOneEpisodeData with the `ep_data` provided.") from exc_ else: diff --git a/grid2op/Episode/EpisodeData.py b/grid2op/Episode/EpisodeData.py index 1925fd7b..6e5b4e7f 100644 --- a/grid2op/Episode/EpisodeData.py +++ b/grid2op/Episode/EpisodeData.py @@ -204,24 +204,33 @@ def __init__( observations, observation_space, "observations", init_me=_init_collections ) - self.env_actions = CollectionWrapper( - env_actions, - helper_action_env, - "env_actions", - check_legit=False, - init_me=_init_collections, - ) + if env_actions is not None: + self.env_actions = CollectionWrapper( + env_actions, + helper_action_env, + "env_actions", + check_legit=False, + init_me=_init_collections, + ) + else: + self.env_actions = None - self.attacks = CollectionWrapper( - attack, attack_space, "attacks", init_me=_init_collections - ) + if attack is not None: + self.attacks = CollectionWrapper( + attack, attack_space, "attacks", init_me=_init_collections + ) + else: + self.attacks = None self.meta = meta # gives a unique game over for everyone # TODO this needs testing! action_go = self.actions._game_over obs_go = self.observations._game_over - env_go = self.env_actions._game_over + if self.env_actions is not None: + env_go = self.env_actions._game_over + else: + env_go = None # raise RuntimeError("Add the attaks game over too !") real_go = action_go if self.meta is not None: @@ -247,7 +256,8 @@ def __init__( # there is a real game over, i assign the proper value for each collection self.actions._game_over = real_go self.observations._game_over = real_go + 1 - self.env_actions._game_over = real_go + if self.env_actions is not None: + self.env_actions._game_over = real_go self.other_rewards = other_rewards self.observation_space = observation_space @@ -401,12 +411,14 @@ def reboot(self): """ self.actions.reboot() self.observations.reboot() - self.env_actions.reboot() + if self.env_actions is not None: + self.env_actions.reboot() def go_to(self, index): self.actions.go_to(index) self.observations.go_to(index + 1) - self.env_actions.go_to(index) + if self.env_actions is not None: + self.env_actions.go_to(index) def get_actions(self): return self.actions.collection @@ -415,13 +427,17 @@ def get_observations(self): return self.observations.collection def __len__(self): - tmp = int(self.meta["chronics_max_timestep"]) - if tmp > 0: - return min(tmp, len(self.observations)) + if self.meta is not None: + tmp = int(self.meta["chronics_max_timestep"]) + if tmp > 0: + return min(tmp, len(self.observations)) return len(self.observations) @classmethod - def from_disk(cls, agent_path, name="1"): + def from_disk(cls, + agent_path: os.PathLike, + name:str="1", + _only_act_obs :bool =False): """ This function allows you to reload an episode stored using the runner. @@ -434,6 +450,9 @@ def from_disk(cls, agent_path, name="1"): name: ``str`` The name of the episode you want to reload. + + _only_act_obs: bool + Load only part of the episode data Returns ------- @@ -448,44 +467,58 @@ def from_disk(cls, agent_path, name="1"): episode_path = os.path.abspath(os.path.join(agent_path, name)) try: - with open(os.path.join(episode_path, EpisodeData.PARAMS)) as f: - _parameters = json.load(fp=f) - with open(os.path.join(episode_path, EpisodeData.META)) as f: - episode_meta = json.load(fp=f) - with open(os.path.join(episode_path, EpisodeData.TIMES)) as f: - episode_times = json.load(fp=f) - with open(os.path.join(episode_path, EpisodeData.OTHER_REWARDS)) as f: - other_rewards = json.load(fp=f) - - times = np.load(os.path.join(episode_path, EpisodeData.AG_EXEC_TIMES))[ - "data" - ] + path_legal_ambiguous = os.path.join(episode_path, cls.LEGAL_AMBIGUOUS) + if _only_act_obs: + _parameters = None + episode_meta = None + episode_times = None + other_rewards = None + times = None + env_actions = None + disc_lines = None + attack = None + rewards = None + has_legal_ambiguous = False + legal = None + ambiguous = None + else: + with open(os.path.join(episode_path, cls.PARAMS)) as f: + _parameters = json.load(fp=f) + with open(os.path.join(episode_path, cls.META)) as f: + episode_meta = json.load(fp=f) + with open(os.path.join(episode_path, cls.TIMES)) as f: + episode_times = json.load(fp=f) + with open(os.path.join(episode_path, cls.OTHER_REWARDS)) as f: + other_rewards = json.load(fp=f) + + times = np.load(os.path.join(episode_path, cls.AG_EXEC_TIMES))[ + "data" + ] + env_actions = np.load(os.path.join(episode_path, cls.ENV_ACTIONS_FILE))[ + "data" + ] + disc_lines = np.load( + os.path.join(episode_path, cls.LINES_FAILURES) + )["data"] + rewards = np.load(os.path.join(episode_path, cls.REWARDS))["data"] + has_legal_ambiguous = False + if os.path.exists(path_legal_ambiguous): + legal_ambiguous = np.load(path_legal_ambiguous)["data"] + legal = copy.deepcopy(legal_ambiguous[:, 0]) + ambiguous = copy.deepcopy(legal_ambiguous[:, 1]) + has_legal_ambiguous = True + else: + legal = None + ambiguous = None + actions = np.load(os.path.join(episode_path, EpisodeData.ACTIONS_FILE))["data"] - env_actions = np.load(os.path.join(episode_path, EpisodeData.ENV_ACTIONS_FILE))[ - "data" - ] observations = np.load( os.path.join(episode_path, EpisodeData.OBSERVATIONS_FILE) )["data"] - disc_lines = np.load( - os.path.join(episode_path, EpisodeData.LINES_FAILURES) - )["data"] attack = np.load(os.path.join(episode_path, EpisodeData.ATTACK))["data"] - rewards = np.load(os.path.join(episode_path, EpisodeData.REWARDS))["data"] - - path_legal_ambiguous = os.path.join(episode_path, EpisodeData.LEGAL_AMBIGUOUS) - has_legal_ambiguous = False - if os.path.exists(path_legal_ambiguous): - legal_ambiguous = np.load(path_legal_ambiguous)["data"] - legal = copy.deepcopy(legal_ambiguous[:, 0]) - ambiguous = copy.deepcopy(legal_ambiguous[:, 1]) - has_legal_ambiguous = True - else: - legal = None - ambiguous = None - except FileNotFoundError as ex: - raise Grid2OpException(f"EpisodeData file not found \n {str(ex)}") + except FileNotFoundError as exc_: + raise Grid2OpException(f"EpisodeData failed to load the file. Some data are not found.") from exc_ observation_space = ObservationSpace.from_dict( os.path.join(agent_path, EpisodeData.OBS_SPACE) @@ -493,12 +526,15 @@ def from_disk(cls, agent_path, name="1"): action_space = ActionSpace.from_dict( os.path.join(agent_path, EpisodeData.ACTION_SPACE) ) - helper_action_env = ActionSpace.from_dict( - os.path.join(agent_path, EpisodeData.ENV_MODIF_SPACE) - ) attack_space = ActionSpace.from_dict( os.path.join(agent_path, EpisodeData.ATTACK_SPACE) ) + if _only_act_obs: + helper_action_env = None + else: + helper_action_env = ActionSpace.from_dict( + os.path.join(agent_path, EpisodeData.ENV_MODIF_SPACE) + ) if observation_space.glop_version != grid2op.__version__: warnings.warn( 'You are using a "grid2op compatibility" feature (the data you saved ' diff --git a/grid2op/tests/test_env_from_episode.py b/grid2op/tests/test_env_from_episode.py index b55e53ed..72681d7b 100644 --- a/grid2op/tests/test_env_from_episode.py +++ b/grid2op/tests/test_env_from_episode.py @@ -531,7 +531,7 @@ def test_assert_warnings(self): ) -class TestTSFromMultieEpisode(unittest.TestCase): +class TestTSFromMultiEpisode(unittest.TestCase): def setUp(self) -> None: self.env_name = "l2rpn_case14_sandbox" with warnings.catch_warnings(): @@ -613,7 +613,7 @@ def test_basic(self): assert env.chronics_handler.get_id() == f"{path_}@1", f"{env.chronics_handler.get_id()} vs {path_}@1" -class TestTSFromMultieEpisodeWithCache(TestTSFromMultieEpisode): +class TestTSFromMultiEpisodeWithCache(TestTSFromMultiEpisode): def do_i_cache(self): return True