Skip to content

Commit

Permalink
fix issue Grid2op#667
Browse files Browse the repository at this point in the history
Signed-off-by: DONNOT Benjamin <[email protected]>
  • Loading branch information
BDonnot committed Nov 28, 2024
1 parent 2483987 commit 85d2e3e
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 76 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ Native multi agents support:
- [FIXED] the `obs.get_forecast_env` : in some cases the resulting first
observation (obtained from `for_env.reset()`) did not have the correct
topology.
- [FIXED] issue https://github.com/Grid2op/grid2op/issues/665 (`obs.reset()`
was not correctly implemented: some attributes were forgotten)
- [FIXED] issue https://github.com/Grid2op/grid2op/issues/667 (`act.as_serializable_dict()`
was not correctly implemented AND the `_aux_affect_object_int` and `_aux_affect_object_float`
have been also fixed - weird behaviour when you give them a list with the exact length of the
object you tried to modified (for example a list with a size of `n_load` that affected the loads))
- [ADDED] possibility to set the "thermal limits" when calling `env.reset(..., options={"thermal limit": xxx})`
- [ADDED] possibility to retrieve some structural information about elements with
with `gridobj.get_line_info(...)`, `gridobj.get_load_info(...)`, `gridobj.get_gen_info(...)`
Expand All @@ -138,6 +144,10 @@ Native multi agents support:
does not have shunt information but there are not shunts on the grid.
- [IMPROVED] consistency of `MultiMixEnv` in case of automatic_classes (only one
class is generated for all mixes)
- [IMPROVED] the `act.as_serializable_dict()` to be more 'backend agnostic'as
it nows tries to use the name of the elements in the json output
- [IMPROVED] the way shunt data are digested in the `BaseAction` class (it is now
possible to use the same things as for the other types of element)

[1.10.4] - 2024-10-15
-------------------------
Expand Down
176 changes: 100 additions & 76 deletions grid2op/Action/baseAction.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,13 @@ def __deepcopy__(self, memodict={}) -> "BaseAction":

return res

def _aux_serialize_add_key_change(self, attr_nm, dict_key, res):
tmp_ = [int(id_) for id_, val in enumerate(getattr(self, attr_nm)) if val]
def _aux_serialize_add_key_change(self, attr_nm, dict_key, res, vect_id_to_name):
tmp_ = [str(vect_id_to_name[id_]) for id_, val in enumerate(getattr(self, attr_nm)) if val]
if tmp_:
res[dict_key] = tmp_

def _aux_serialize_add_key_set(self, attr_nm, dict_key, res):
tmp_ = [(int(id_), int(val)) for id_, val in enumerate(getattr(self, attr_nm)) if np.abs(val) >= 1e-7]
def _aux_serialize_add_key_set(self, attr_nm, dict_key, res, vect_id_to_name):
tmp_ = [(str(vect_id_to_name[id_]), int(val)) for id_, val in enumerate(getattr(self, attr_nm)) if np.abs(val) >= 1e-7]
if tmp_:
res[dict_key] = tmp_

Expand Down Expand Up @@ -651,37 +651,37 @@ def as_serializable_dict(self) -> dict:

if self._modif_change_bus:
res["change_bus"] = {}
self._aux_serialize_add_key_change("load_change_bus", "loads_id", res["change_bus"])
self._aux_serialize_add_key_change("gen_change_bus", "generators_id", res["change_bus"])
self._aux_serialize_add_key_change("line_or_change_bus", "lines_or_id", res["change_bus"])
self._aux_serialize_add_key_change("line_ex_change_bus", "lines_ex_id", res["change_bus"])
self._aux_serialize_add_key_change("load_change_bus", "loads_id", res["change_bus"], cls.name_load)
self._aux_serialize_add_key_change("gen_change_bus", "generators_id", res["change_bus"], cls.name_gen)
self._aux_serialize_add_key_change("line_or_change_bus", "lines_or_id", res["change_bus"], cls.name_line)
self._aux_serialize_add_key_change("line_ex_change_bus", "lines_ex_id", res["change_bus"], cls.name_line)
if hasattr(cls, "n_storage") and cls.n_storage:
self._aux_serialize_add_key_change("storage_change_bus", "storages_id", res["change_bus"])
self._aux_serialize_add_key_change("storage_change_bus", "storages_id", res["change_bus"], cls.name_storage)
if not res["change_bus"]:
del res["change_bus"]

if self._modif_change_status:
res["change_line_status"] = [
int(id_) for id_, val in enumerate(self._switch_line_status) if val
str(cls.name_line[id_]) for id_, val in enumerate(self._switch_line_status) if val
]
if not res["change_line_status"]:
del res["change_line_status"]

# int elements
if self._modif_set_bus:
res["set_bus"] = {}
self._aux_serialize_add_key_set("load_set_bus", "loads_id", res["set_bus"])
self._aux_serialize_add_key_set("gen_set_bus", "generators_id", res["set_bus"])
self._aux_serialize_add_key_set("line_or_set_bus", "lines_or_id", res["set_bus"])
self._aux_serialize_add_key_set("line_ex_set_bus", "lines_ex_id", res["set_bus"])
self._aux_serialize_add_key_set("load_set_bus", "loads_id", res["set_bus"], cls.name_load)
self._aux_serialize_add_key_set("gen_set_bus", "generators_id", res["set_bus"], cls.name_gen)
self._aux_serialize_add_key_set("line_or_set_bus", "lines_or_id", res["set_bus"], cls.name_line)
self._aux_serialize_add_key_set("line_ex_set_bus", "lines_ex_id", res["set_bus"], cls.name_line)
if hasattr(cls, "n_storage") and cls.n_storage:
self._aux_serialize_add_key_set("storage_set_bus", "storages_id", res["set_bus"])
self._aux_serialize_add_key_set("storage_set_bus", "storages_id", res["set_bus"], cls.name_storage)
if not res["set_bus"]:
del res["set_bus"]

if self._modif_set_status:
res["set_line_status"] = [
(int(id_), int(val))
(str(cls.name_line[id_]), int(val))
for id_, val in enumerate(self._set_line_status)
if val != 0
]
Expand All @@ -691,7 +691,7 @@ def as_serializable_dict(self) -> dict:
# float elements
if self._modif_redispatch:
res["redispatch"] = [
(int(id_), float(val))
(str(cls.name_gen[id_]), float(val))
for id_, val in enumerate(self._redispatch)
if np.abs(val) >= 1e-7
]
Expand All @@ -700,7 +700,7 @@ def as_serializable_dict(self) -> dict:

if self._modif_storage:
res["set_storage"] = [
(int(id_), float(val))
(str(cls.name_storage[id_]), float(val))
for id_, val in enumerate(self._storage_power)
if np.abs(val) >= 1e-7
]
Expand All @@ -709,7 +709,7 @@ def as_serializable_dict(self) -> dict:

if self._modif_curtailment:
res["curtail"] = [
(int(id_), float(val))
(str(cls.name_gen[id_]), float(val))
for id_, val in enumerate(self._curtail)
if np.abs(val + 1.) >= 1e-7
]
Expand All @@ -719,9 +719,10 @@ def as_serializable_dict(self) -> dict:
# more advanced options
if self._modif_inj:
res["injection"] = {}
for ky in ["prod_p", "prod_v", "load_p", "load_q"]:
for ky, vect_nm in zip(["prod_p", "prod_v", "load_p", "load_q"],
[cls.name_gen, cls.name_gen, cls.name_load, cls.name_load]):
if ky in self._dict_inj:
res["injection"][ky] = [float(val) for val in self._dict_inj[ky]]
res["injection"][ky] = {str(vect_nm[i]): float(val) for i, val in enumerate(self._dict_inj[ky])}
if not res["injection"]:
del res["injection"]

Expand Down Expand Up @@ -1860,60 +1861,82 @@ def __call__(self) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray

def _digest_shunt(self, dict_):
cls = type(self)
if "shunt" in dict_:
ddict_ = dict_["shunt"]

key_shunt_reco = {"set_bus", "shunt_p", "shunt_q", "shunt_bus"}
for k in ddict_:
if k not in key_shunt_reco:
warn = "The key {} is not recognized by BaseAction when trying to modify the shunt.".format(
k
if "shunt" not in dict_:
return
ddict_ = dict_["shunt"]

key_shunt_reco = {"set_bus", "shunt_p", "shunt_q", "shunt_bus"}
for k in ddict_:
if k not in key_shunt_reco:
warn = "The key {} is not recognized by BaseAction when trying to modify the shunt.".format(
k
)
warn += " Recognized keys are {}".format(sorted(key_shunt_reco))
warnings.warn(warn)

for key_n, vect_self in zip(
["shunt_bus", "shunt_p", "shunt_q", "set_bus"],
[self.shunt_bus, self.shunt_p, self.shunt_q, self.shunt_bus],
):
if key_n in ddict_:
tmp = ddict_[key_n]
if tmp is None:
pass
elif key_n == "shunt_bus" or key_n == "set_bus":
self._aux_affect_object_int(
tmp,
key_n,
cls.n_shunt,
cls.name_shunt,
np.arange(cls.n_shunt),
vect_self,
max_val=cls.n_busbar_per_sub
)
elif key_n == "shunt_p" or key_n == "shunt_q":
self._aux_affect_object_float(
tmp,
key_n,
cls.n_shunt,
cls.name_shunt,
np.arange(cls.n_shunt),
vect_self
)

# if isinstance(tmp, np.ndarray):
# # complete shunt vector is provided
# vect_self[:] = tmp
# elif isinstance(tmp, list):
# # expected a list: (id shunt, new bus)
# for (sh_id, new_bus) in tmp:
# if sh_id < 0:
# raise AmbiguousAction(
# "Invalid shunt id {}. Shunt id should be positive".format(
# sh_id
# )
# )
# if sh_id >= cls.n_shunt:
# raise AmbiguousAction(
# "Invalid shunt id {}. Shunt id should be less than the number "
# "of shunt {}".format(sh_id, cls.n_shunt)
# )
# if key_n == "shunt_bus" or key_n == "set_bus":
# if new_bus <= -2:
# raise IllegalAction(
# f"Cannot ask for a shunt bus <= -2, found {new_bus} for shunt id {sh_id}"
# )
# elif new_bus > cls.n_busbar_per_sub:
# raise IllegalAction(
# f"Cannot ask for a shunt bus > {cls.n_busbar_per_sub} "
# f"the maximum number of busbar per substations"
# f", found {new_bus} for shunt id {sh_id}"
# )

# vect_self[sh_id] = new_bus
else:
raise AmbiguousAction(
"Invalid way to modify {} for shunts. It should be a numpy array or a "
"list, found {}.".format(key_n, type(tmp))
)
warn += " Recognized keys are {}".format(sorted(key_shunt_reco))
warnings.warn(warn)
for key_n, vect_self in zip(
["shunt_bus", "shunt_p", "shunt_q", "set_bus"],
[self.shunt_bus, self.shunt_p, self.shunt_q, self.shunt_bus],
):
if key_n in ddict_:
tmp = ddict_[key_n]
if isinstance(tmp, np.ndarray):
# complete shunt vector is provided
vect_self[:] = tmp
elif isinstance(tmp, list):
# expected a list: (id shunt, new bus)
for (sh_id, new_bus) in tmp:
if sh_id < 0:
raise AmbiguousAction(
"Invalid shunt id {}. Shunt id should be positive".format(
sh_id
)
)
if sh_id >= cls.n_shunt:
raise AmbiguousAction(
"Invalid shunt id {}. Shunt id should be less than the number "
"of shunt {}".format(sh_id, cls.n_shunt)
)
if key_n == "shunt_bus" or key_n == "set_bus":
if new_bus <= -2:
raise IllegalAction(
f"Cannot ask for a shunt bus <= -2, found {new_bus} for shunt id {sh_id}"
)
elif new_bus > cls.n_busbar_per_sub:
raise IllegalAction(
f"Cannot ask for a shunt bus > {cls.n_busbar_per_sub} "
f"the maximum number of busbar per substations"
f", found {new_bus} for shunt id {sh_id}"
)

vect_self[sh_id] = new_bus
elif tmp is None:
pass
else:
raise AmbiguousAction(
"Invalid way to modify {} for shunts. It should be a numpy array or a "
"dictionary.".format(key_n)
)

def _digest_injection(self, dict_):
# I update the action
Expand Down Expand Up @@ -2265,6 +2288,7 @@ def update(self,
- "curtail" : TODO
- "raise_alarm" : TODO
- "raise_alert": TODO
- "shunt": TODO
**NB**: CHANGES: you can reconnect a powerline without specifying on each bus you reconnect it at both its
ends. In that case the last known bus id for each its end is used.
Expand Down Expand Up @@ -4059,7 +4083,7 @@ def _aux_affect_object_int(
if len(values) == nb_els:
# 2 cases: either i set all loads in the form [(0,..), (1,..), (2,...)]
# or i should have converted the list to np array
if isinstance(values[0], tuple):
if isinstance(values[0], (tuple, list)):
# list of tuple, handled below
# TODO can be somewhat "hacked" if the type of the object on the list is not always the same
pass
Expand Down Expand Up @@ -5492,7 +5516,7 @@ def _aux_affect_object_float(
raise IllegalAction(
f"Impossible to set {name_el} values with a single float."
)
elif isinstance(values[0], tuple):
elif isinstance(values[0], (tuple, list)):
# list of tuple, handled below
# TODO can be somewhat "hacked" if the type of the object on the list is not always the same
pass
Expand Down
67 changes: 67 additions & 0 deletions grid2op/tests/test_issue_665.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt and https://github.com/Grid2Op/grid2op/pull/319
# This Source Code Form is subject to the terms of the Mozilla Public License, version 2.0.
# If a copy of the Mozilla Public License, version 2.0 was not distributed with this file,
# you can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
# This file is part of Grid2Op, Grid2Op a testbed platform to model sequential decision making in power systems.

import numpy as np
from logging import Logger
import unittest
import warnings


from helper_path_test import PATH_DATA_TEST
import grid2op
from grid2op.dtypes import dt_int, dt_float
from grid2op.gym_compat import BoxGymObsSpace
from grid2op.gym_compat.utils import _compute_extra_power_for_losses
from grid2op.Exceptions import ChronicsError, EnvError


class Issue665Tester(unittest.TestCase):
def setUp(self):
self.env_name = "l2rpn_idf_2023"
# create first env
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
self.env = grid2op.make("l2rpn_idf_2023", test=True)
self.dict_properties = BoxGymObsSpace(self.env.observation_space)._dict_properties

def tearDown(self) -> None:
self.env.close()
return super().tearDown()

def test_issue_665(self):
attributes_names = set(self.dict_properties.keys())
attr_with_a_problem = set() # I put an attribute here if at least one bound has been exceeded at least once
attr_without_a_problem = set(self.dict_properties.keys()) # I remove an attribute from here if at least one bound has been exceeded at least once

i = 0
while i < 5 and not attr_without_a_problem:
obs = self.env.reset()
obs_temp = self.env.observation_space._template_obj

# I check only attributes which has not exceeded their bounds yet
for attr_name in attr_without_a_problem:
attr = getattr(obs_temp, attr_name)
low = self.dict_properties[attr_name][0]
high = self.dict_properties[attr_name][1]

ids = np.where((attr < low) | (attr > high))[0]
if ids.shape[0] > 0: # Case where at least a bound has been exceeded
# I uppdate my set
attr_with_a_problem.add(attr_name)
# I print a value (the one with the lower index) that exceeded its bounds
id0 = ids[0]
print(f"The {attr_name} attribute is out of the bounds with index {id0}. Bounds : {low[id0]} <= {high[id0]}, value: {attr[id0]}.")

# I uppdate my set
attr_without_a_problem = attributes_names - attr_with_a_problem
i+=1

assert not attr_with_a_problem

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 85d2e3e

Please sign in to comment.