From a31ae59ea8d50fcc8e66e0de0230b7d40ba5e9c0 Mon Sep 17 00:00:00 2001 From: Blunde1 Date: Wed, 15 May 2024 11:36:43 +0200 Subject: [PATCH] go down to working levels --- src/ert/analysis/_es_update.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index c36e1796f7f..c570bee1f42 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -34,6 +34,7 @@ from iterative_ensemble_smoother.experimental import ( AdaptiveESMDA, ) +from sklearn.preprocessing import StandardScaler from typing_extensions import Self from ..config.analysis_module import ESSettings, IESSettings @@ -667,8 +668,8 @@ def adaptive_localization_progress_callback( source, iens_active_index, param_group ) X_local = temp_storage[param_group] - # X_local_scaler = StandardScaler() - # X_scaled = X_local_scaler.fit_transform(X_local.T) + X_local_scaler = StandardScaler() + X_scaled = X_local_scaler.fit_transform(X_local.T) # scaler_cache[param_group] = X_local_scaler graph_u_sub = config_node.load_parameter_graph( @@ -683,13 +684,13 @@ def adaptive_localization_progress_callback( Z = sp.sparse.csc_matrix(Z) Prec_u_sub = gspme.prec_sparse( - X_local.T, + X_scaled, Z, markov_order=1, cov_shrinkage=True, symmetrization=False, shrinkage_target=2, - inflation_factor=15.0, + inflation_factor=10.0, ) # # A very simple hash key for graph @@ -770,13 +771,13 @@ def adaptive_localization_progress_callback( ) as file: pickle.dump(observation_errors, file) - # X_full_scaler = StandardScaler() - # X_full_scaled = X_full_scaler.fit_transform(X_full.T) - # print(f"Scaled X_full has shape: {X_full_scaled.shape}") + X_full_scaler = StandardScaler() + X_full_scaled = X_full_scaler.fit_transform(X_full.T) + print(f"Scaled X_full has shape: {X_full_scaled.shape}") # Call fit: Learn sparse linear map only H = linear_boost_ic_regression( - U=X_full.T, + U=X_full_scaled, Y=S.T, learning_rate=0.95, effective_dimension=0, @@ -803,7 +804,7 @@ def adaptive_localization_progress_callback( # Call transport? might have to do some coding here # Perhaps use an iterative solver instead of direct spsolve or similar X_full = gtmap.transport( - X_full.T, + X_full_scaled, S.T, observation_values, update_indices=update_indices,