
Commit b658793

improve reading speed of FromEpisodeData by not reading everything, see issue Grid2op#659

Signed-off-by: DONNOT Benjamin <[email protected]>
BDonnot committed Nov 8, 2024
1 parent f259521 commit b658793
Showing 3 changed files with 93 additions and 58 deletions.
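
For context, a minimal sketch of the shortcut this commit introduces (the `./saved_agent` directory and episode name `"0000"` are placeholders for output previously written by the grid2op Runner, not paths taken from the diff):

from grid2op.Episode import EpisodeData

# placeholders: an agent directory produced by the grid2op Runner and one episode in it
agent_path = "./saved_agent"
episode_name = "0000"

# default behaviour: every file of the episode is read back
ep_full = EpisodeData.from_disk(agent_path, episode_name)

# fast path added by this commit (used internally by FromOneEpisodeData):
# only the actions and observations are read; parameters, rewards,
# env actions, attacks and meta information are skipped
ep_fast = EpisodeData.from_disk(agent_path, episode_name, _only_act_obs=True)

The leading underscore marks the flag as internal: regular code is expected to go through FromOneEpisodeData (first file below) rather than pass it directly.
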
5 changes: 2 additions & 3 deletions grid2op/Chronics/fromOneEpisodeData.py
@@ -177,20 +177,19 @@ def __init__(
if self.path is not None:
# logger: this has no impact
pass

if isinstance(ep_data, EpisodeData):
self._episode_data = ep_data
elif isinstance(ep_data, (str, Path)):
try:
self._episode_data = EpisodeData.from_disk(*os.path.split(ep_data))
self._episode_data = EpisodeData.from_disk(*os.path.split(ep_data), _only_act_obs=True)
except Exception as exc_:
raise ChronicsError("Impossible to build the FromOneEpisodeData with the `ep_data` provided.") from exc_
elif isinstance(ep_data, (tuple, list)):
if len(ep_data) != 2:
raise ChronicsError("When you provide a tuple, or a list, FromOneEpisodeData can only be used if this list has length 2. "
f"Length {len(ep_data)} found.")
try:
self._episode_data = EpisodeData.from_disk(*ep_data)
self._episode_data = EpisodeData.from_disk(*ep_data, _only_act_obs=True)
except Exception as exc_:
raise ChronicsError("Impossible to build the FromOneEpisodeData with the `ep_data` provided.") from exc_
else:
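
The change above only forwards the new flag; for reference, this is how FromOneEpisodeData is typically plugged into an environment (a sketch following the usual grid2op chronics pattern; `l2rpn_case14_sandbox` and the stored-episode paths are placeholder assumptions, and the `ep_data` tuple matches the length-2 `(agent_path, name)` form checked in the diff above):

import warnings
import grid2op
from grid2op.Chronics import FromOneEpisodeData

env_name = "l2rpn_case14_sandbox"      # placeholder environment name
ep_data = ("./saved_agent", "0000")    # (agent_path, episode name), forwarded to EpisodeData.from_disk

with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    env = grid2op.make(env_name,
                       chronics_class=FromOneEpisodeData,
                       data_feeding_kwargs={"ep_data": ep_data})

# with this commit, building the chronics only reads the actions and
# observations of the stored episode instead of the full EpisodeData
obs = env.reset()
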
142 changes: 89 additions & 53 deletions grid2op/Episode/EpisodeData.py
@@ -204,24 +204,33 @@ def __init__(
observations, observation_space, "observations", init_me=_init_collections
)

self.env_actions = CollectionWrapper(
env_actions,
helper_action_env,
"env_actions",
check_legit=False,
init_me=_init_collections,
)
if env_actions is not None:
self.env_actions = CollectionWrapper(
env_actions,
helper_action_env,
"env_actions",
check_legit=False,
init_me=_init_collections,
)
else:
self.env_actions = None

self.attacks = CollectionWrapper(
attack, attack_space, "attacks", init_me=_init_collections
)
if attack is not None:
self.attacks = CollectionWrapper(
attack, attack_space, "attacks", init_me=_init_collections
)
else:
self.attacks = None

self.meta = meta
# gives a unique game over for everyone
# TODO this needs testing!
action_go = self.actions._game_over
obs_go = self.observations._game_over
env_go = self.env_actions._game_over
if self.env_actions is not None:
env_go = self.env_actions._game_over
else:
env_go = None
# raise RuntimeError("Add the attaks game over too !")
real_go = action_go
if self.meta is not None:
@@ -247,7 +256,8 @@ def __init__(
# there is a real game over, i assign the proper value for each collection
self.actions._game_over = real_go
self.observations._game_over = real_go + 1
self.env_actions._game_over = real_go
if self.env_actions is not None:
self.env_actions._game_over = real_go

self.other_rewards = other_rewards
self.observation_space = observation_space
@@ -401,12 +411,14 @@ def reboot(self):
"""
self.actions.reboot()
self.observations.reboot()
self.env_actions.reboot()
if self.env_actions is not None:
self.env_actions.reboot()

def go_to(self, index):
self.actions.go_to(index)
self.observations.go_to(index + 1)
self.env_actions.go_to(index)
if self.env_actions is not None:
self.env_actions.go_to(index)

def get_actions(self):
return self.actions.collection
@@ -415,13 +427,17 @@ def get_observations(self):
return self.observations.collection

def __len__(self):
tmp = int(self.meta["chronics_max_timestep"])
if tmp > 0:
return min(tmp, len(self.observations))
if self.meta is not None:
tmp = int(self.meta["chronics_max_timestep"])
if tmp > 0:
return min(tmp, len(self.observations))
return len(self.observations)

@classmethod
def from_disk(cls, agent_path, name="1"):
def from_disk(cls,
              agent_path: os.PathLike,
              name: str = "1",
              _only_act_obs: bool = False):
"""
This function allows you to reload an episode stored using the runner.
@@ -434,6 +450,9 @@ def from_disk(cls, agent_path, name="1"):
name: ``str``
The name of the episode you want to reload.
_only_act_obs: ``bool``
Internal: if ``True``, only the actions and observations of the episode are loaded; the other data (parameters, rewards, env actions, attacks and meta information) are not read from disk.
Returns
-------
@@ -448,57 +467,74 @@
episode_path = os.path.abspath(os.path.join(agent_path, name))

try:
with open(os.path.join(episode_path, EpisodeData.PARAMS)) as f:
_parameters = json.load(fp=f)
with open(os.path.join(episode_path, EpisodeData.META)) as f:
episode_meta = json.load(fp=f)
with open(os.path.join(episode_path, EpisodeData.TIMES)) as f:
episode_times = json.load(fp=f)
with open(os.path.join(episode_path, EpisodeData.OTHER_REWARDS)) as f:
other_rewards = json.load(fp=f)

times = np.load(os.path.join(episode_path, EpisodeData.AG_EXEC_TIMES))[
"data"
]
path_legal_ambiguous = os.path.join(episode_path, cls.LEGAL_AMBIGUOUS)
if _only_act_obs:
_parameters = None
episode_meta = None
episode_times = None
other_rewards = None
times = None
env_actions = None
disc_lines = None
attack = None
rewards = None
has_legal_ambiguous = False
legal = None
ambiguous = None
else:
with open(os.path.join(episode_path, cls.PARAMS)) as f:
_parameters = json.load(fp=f)
with open(os.path.join(episode_path, cls.META)) as f:
episode_meta = json.load(fp=f)
with open(os.path.join(episode_path, cls.TIMES)) as f:
episode_times = json.load(fp=f)
with open(os.path.join(episode_path, cls.OTHER_REWARDS)) as f:
other_rewards = json.load(fp=f)

times = np.load(os.path.join(episode_path, cls.AG_EXEC_TIMES))[
"data"
]
env_actions = np.load(os.path.join(episode_path, cls.ENV_ACTIONS_FILE))[
"data"
]
disc_lines = np.load(
os.path.join(episode_path, cls.LINES_FAILURES)
)["data"]
rewards = np.load(os.path.join(episode_path, cls.REWARDS))["data"]
has_legal_ambiguous = False
if os.path.exists(path_legal_ambiguous):
legal_ambiguous = np.load(path_legal_ambiguous)["data"]
legal = copy.deepcopy(legal_ambiguous[:, 0])
ambiguous = copy.deepcopy(legal_ambiguous[:, 1])
has_legal_ambiguous = True
else:
legal = None
ambiguous = None

actions = np.load(os.path.join(episode_path, EpisodeData.ACTIONS_FILE))["data"]
env_actions = np.load(os.path.join(episode_path, EpisodeData.ENV_ACTIONS_FILE))[
"data"
]
observations = np.load(
os.path.join(episode_path, EpisodeData.OBSERVATIONS_FILE)
)["data"]
disc_lines = np.load(
os.path.join(episode_path, EpisodeData.LINES_FAILURES)
)["data"]
attack = np.load(os.path.join(episode_path, EpisodeData.ATTACK))["data"]
rewards = np.load(os.path.join(episode_path, EpisodeData.REWARDS))["data"]

path_legal_ambiguous = os.path.join(episode_path, EpisodeData.LEGAL_AMBIGUOUS)
has_legal_ambiguous = False
if os.path.exists(path_legal_ambiguous):
legal_ambiguous = np.load(path_legal_ambiguous)["data"]
legal = copy.deepcopy(legal_ambiguous[:, 0])
ambiguous = copy.deepcopy(legal_ambiguous[:, 1])
has_legal_ambiguous = True
else:
legal = None
ambiguous = None

except FileNotFoundError as ex:
raise Grid2OpException(f"EpisodeData file not found \n {str(ex)}")
except FileNotFoundError as exc_:
raise Grid2OpException(f"EpisodeData failed to load the file. Some data are not found.") from exc_

observation_space = ObservationSpace.from_dict(
os.path.join(agent_path, EpisodeData.OBS_SPACE)
)
action_space = ActionSpace.from_dict(
os.path.join(agent_path, EpisodeData.ACTION_SPACE)
)
helper_action_env = ActionSpace.from_dict(
os.path.join(agent_path, EpisodeData.ENV_MODIF_SPACE)
)
attack_space = ActionSpace.from_dict(
os.path.join(agent_path, EpisodeData.ATTACK_SPACE)
)
if _only_act_obs:
helper_action_env = None
else:
helper_action_env = ActionSpace.from_dict(
os.path.join(agent_path, EpisodeData.ENV_MODIF_SPACE)
)
if observation_space.glop_version != grid2op.__version__:
warnings.warn(
'You are using a "grid2op compatibility" feature (the data you saved '
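
A practical consequence of the guards added above: an EpisodeData loaded with `_only_act_obs=True` exposes ``None`` for everything that was skipped, so downstream code has to check those attributes before using them (illustrative sketch, paths are placeholders):

from grid2op.Episode import EpisodeData

ep = EpisodeData.from_disk("./saved_agent", "0000", _only_act_obs=True)

# actions and observations are fully available
for act in ep.actions:
    pass  # replay / analysis code would go here
for obs in ep.observations:
    pass

# everything that was not read back stays None in this mode, as the
# modified __init__, reboot, go_to and __len__ above now allow
assert ep.meta is None
assert ep.env_actions is None
assert ep.attacks is None
n_steps = len(ep)  # falls back to len(ep.observations) when meta is None
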
4 changes: 2 additions & 2 deletions grid2op/tests/test_env_from_episode.py
@@ -531,7 +531,7 @@ def test_assert_warnings(self):
)


class TestTSFromMultieEpisode(unittest.TestCase):
class TestTSFromMultiEpisode(unittest.TestCase):
def setUp(self) -> None:
self.env_name = "l2rpn_case14_sandbox"
with warnings.catch_warnings():
@@ -613,7 +613,7 @@ def test_basic(self):
assert env.chronics_handler.get_id() == f"{path_}@1", f"{env.chronics_handler.get_id()} vs {path_}@1"


class TestTSFromMultieEpisodeWithCache(TestTSFromMultieEpisode):
class TestTSFromMultiEpisodeWithCache(TestTSFromMultiEpisode):
def do_i_cache(self):
return True

