diff --git a/Grid2Kpi.iml b/Grid2Kpi.iml new file mode 100644 index 0000000..ad3c0a3 --- /dev/null +++ b/Grid2Kpi.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/grid2kpi/episode/EpisodeAnalytics.py b/grid2kpi/episode/EpisodeAnalytics.py index bb5f1be..d76d692 100644 --- a/grid2kpi/episode/EpisodeAnalytics.py +++ b/grid2kpi/episode/EpisodeAnalytics.py @@ -272,6 +272,15 @@ def _env_actions_as_df(self): return hazards, maintenances + def get_prod_types(self): + types = self.observation_space.gen_type + ret = {} + if types is None: + return ret + for (idx, name) in enumerate(self.prod_names): + ret[name] = types[idx] + return ret + class Test(): def __init__(self): diff --git a/grid2kpi/episode/EpisodeTrace.py b/grid2kpi/episode/EpisodeTrace.py index 3839b65..b3a7259 100644 --- a/grid2kpi/episode/EpisodeTrace.py +++ b/grid2kpi/episode/EpisodeTrace.py @@ -24,7 +24,8 @@ def get_total_overflow_ts(episode): return df -def get_prod_share_trace(episode, prod_types): +def get_prod_share_trace(episode): + prod_types = episode.get_prod_types() prod_type_values = list(prod_types.values()) if len( prod_types.values()) > 0 else [] @@ -106,7 +107,7 @@ def get_all_prod_trace(episode, prod_types, selection): if name in selection: trace.append(go.Scatter( x=prod_with_type[prod_with_type.prod_type.values == - name]['timestamp'].unique(), + name]['timestamp'].drop_duplicates(), y=prod_with_type[prod_with_type.prod_type.values == name].groupby(['timestamp'])[ 'value'].sum(), name=name