Skip to content

Commit

Permalink
go down to working levels
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed May 15, 2024
1 parent 281c0cf commit a31ae59
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
)
from sklearn.preprocessing import StandardScaler
from typing_extensions import Self

from ..config.analysis_module import ESSettings, IESSettings
Expand Down Expand Up @@ -667,8 +668,8 @@ def adaptive_localization_progress_callback(
source, iens_active_index, param_group
)
X_local = temp_storage[param_group]
# X_local_scaler = StandardScaler()
# X_scaled = X_local_scaler.fit_transform(X_local.T)
X_local_scaler = StandardScaler()
X_scaled = X_local_scaler.fit_transform(X_local.T)
# scaler_cache[param_group] = X_local_scaler

graph_u_sub = config_node.load_parameter_graph(
Expand All @@ -683,13 +684,13 @@ def adaptive_localization_progress_callback(
Z = sp.sparse.csc_matrix(Z)

Prec_u_sub = gspme.prec_sparse(
X_local.T,
X_scaled,
Z,
markov_order=1,
cov_shrinkage=True,
symmetrization=False,
shrinkage_target=2,
inflation_factor=15.0,
inflation_factor=10.0,
)

# # A very simple hash key for graph
Expand Down Expand Up @@ -770,13 +771,13 @@ def adaptive_localization_progress_callback(
) as file:
pickle.dump(observation_errors, file)

# X_full_scaler = StandardScaler()
# X_full_scaled = X_full_scaler.fit_transform(X_full.T)
# print(f"Scaled X_full has shape: {X_full_scaled.shape}")
X_full_scaler = StandardScaler()
X_full_scaled = X_full_scaler.fit_transform(X_full.T)
print(f"Scaled X_full has shape: {X_full_scaled.shape}")

# Call fit: Learn sparse linear map only
H = linear_boost_ic_regression(
U=X_full.T,
U=X_full_scaled,
Y=S.T,
learning_rate=0.95,
effective_dimension=0,
Expand All @@ -803,7 +804,7 @@ def adaptive_localization_progress_callback(
# Call transport? might have to do some coding here
# Perhaps use an iterative solver instead of direct spsolve or similar
X_full = gtmap.transport(
X_full.T,
X_full_scaled,
S.T,
observation_values,
update_indices=update_indices,
Expand Down

0 comments on commit a31ae59

Please sign in to comment.