From 9f99e9352af2a6a63cc7c3527ae9a9cf88ad2e1b Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Fri, 1 Dec 2023 09:15:54 +0100 Subject: [PATCH] Add style and fix last tests --- src/ert/analysis/_es_update.py | 4 +++- src/ert/config/summary_config.py | 2 +- src/ert/data/_measured_data.py | 21 ++++++++++++------- src/ert/libres_facade.py | 7 +++---- .../unit_tests/data/test_integration_data.py | 7 ++++--- tests/unit_tests/test_load_forward_model.py | 9 ++++---- 6 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index d79a4acb017..a207b0452ed 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -331,7 +331,9 @@ def _get_obs_and_measure_data( ds = source_fs.load_responses(group, tuple(iens_active_index)) if "time" in observation.coords: - observation.coords["time"]= [t[:-3] for t in observation.coords["time"].values.astype(str)] + observation.coords["time"] = [ + t[:-3] for t in observation.coords["time"].values.astype(str) + ] try: filtered_ds = observation.merge(ds, join="left") diff --git a/src/ert/config/summary_config.py b/src/ert/config/summary_config.py index 879578d8cea..8a79c1da64d 100644 --- a/src/ert/config/summary_config.py +++ b/src/ert/config/summary_config.py @@ -62,7 +62,7 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: summary_data.sort(key=lambda x: x[0]) data = [d for _, d in summary_data] keys = [k for k, _ in summary_data] - time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map] + time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map] ds = xr.Dataset( {"values": (["name", "time"], data)}, coords={"time": time_map, "name": keys}, diff --git a/src/ert/data/_measured_data.py b/src/ert/data/_measured_data.py index 9960d9f504b..cb16c6f6870 100644 --- a/src/ert/data/_measured_data.py +++ b/src/ert/data/_measured_data.py @@ -114,9 +114,11 @@ def _get_data( raise ResponseError(_msg) except KeyError as e: raise ResponseError(_msg) from e - + if "time" in obs.coords: - obs.coords["time"]= [t[:-3] for t in obs.coords["time"].values.astype(str)] + obs.coords["time"] = [ + t[:-3] for t in obs.coords["time"].values.astype(str) + ] ds = obs.merge( response, @@ -134,11 +136,11 @@ def _get_data( ds = ds.rename(time="key_index") ds = ds.assign_coords({"name": [key]}) - new_index = pd.DatetimeIndex(response.indexes["time"].values.astype('datetime64[ns]')) - data_index = [ - new_index.get_loc(date) for date in obs.time.values - ] - #data_index = [response.indexes["time"].get_loc(date) for date in obs.time.values ] + new_index = pd.DatetimeIndex( + response.indexes["time"].values.astype("datetime64[ns]") + ) + data_index = [new_index.get_loc(date) for date in obs.time.values] + # data_index = [response.indexes["time"].get_loc(date) for date in obs.time.values ] index_vals = ds.observations.coords.to_index( ["name", "key_index"] @@ -210,7 +212,10 @@ def _create_condition( for obs_key, index_list in zip(obs_keys, index_lists): if index_list is not None: if isinstance(index_list[0], datetime): - index_list= [datetime.isoformat(t, timespec="microseconds") for t in index_list] + index_list = [ + datetime.isoformat(t, timespec="microseconds") + for t in index_list + ] index_cond = [data_index == index for index in index_list] index_cond = np.logical_or.reduce(index_cond) conditions.append(np.logical_and(index_cond, (names == obs_key))) diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index 78b3d5d3872..a6e5b583206 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -416,16 +416,15 @@ def load_all_summary_data( ) except (ValueError, KeyError): return pd.DataFrame() + + # Remove the time part of the 'time' index + df.index = df.index.set_levels([t[:-16] for t in df.index.levels[0]], level=0) df = df.unstack(level="name") df.columns = [col[1] for col in df.columns.values] df.index = df.index.rename( {"time": "Date", "realization": "Realization"} ).reorder_levels(["Realization", "Date"]) - # remove time part - - - if keys: summary_keys = sorted( [key for key in keys if key in summary_keys] diff --git a/tests/unit_tests/data/test_integration_data.py b/tests/unit_tests/data/test_integration_data.py index 8ce3263941c..0600ac3cc59 100644 --- a/tests/unit_tests/data/test_integration_data.py +++ b/tests/unit_tests/data/test_integration_data.py @@ -48,9 +48,10 @@ def test_summary_obs(create_measured_data): summary_obs.remove_inactive_observations() assert all(summary_obs.data.columns.get_level_values("data_index").values == [71]) # Only one observation, we check the key_index is what we expect: - assert summary_obs.data.columns.get_level_values("key_index").values[ - 0 - ] == "2011-12-21T00:00:00.000000" + assert ( + summary_obs.data.columns.get_level_values("key_index").values[0] + == "2011-12-21T00:00:00.000000" + ) @pytest.mark.filterwarnings("ignore::ert.config.ConfigWarning") diff --git a/tests/unit_tests/test_load_forward_model.py b/tests/unit_tests/test_load_forward_model.py index 4f82b82cdf4..edaae0f8c95 100644 --- a/tests/unit_tests/test_load_forward_model.py +++ b/tests/unit_tests/test_load_forward_model.py @@ -1,5 +1,3 @@ - -import xarray as xr import fileinput import logging import os @@ -9,6 +7,7 @@ import numpy as np import pytest +import xarray as xr from resdata.summary import Summary from ert.config import ErtConfig @@ -142,9 +141,9 @@ def test_datetime_2500(): realizations = [False] * facade.get_ensemble_size() realizations[realisation_number] = True facade.load_from_forward_model(ensemble, realizations, 0) - - dataset= ensemble.load_responses("summary", tuple([0])) - assert dataset.coords["time"].data.dtype == np.dtype('object') + + dataset = ensemble.load_responses("summary", tuple([0])) + assert dataset.coords["time"].data.dtype == np.dtype("object") @pytest.mark.usefixtures("copy_snake_oil_case_storage")