Skip to content

Commit

Permalink
Reduce nesting in analysis_ES
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 19, 2023
1 parent 603daad commit c765c7f
Showing 1 changed file with 127 additions and 89 deletions.
216 changes: 127 additions & 89 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,74 @@ def _update_with_row_scaling(
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)


def _determine_inversion_type(ies_inversion: int, update_step_name: str) -> str:
inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"}

try:
inversion_type = inversion_types[ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for update step: {update_step_name}. "
f"Specified: {ies_inversion}, with possible: {list(inversion_types.keys())}"
) from e

return inversion_type


def _determine_parameter_source(
param_group_name: str, target_fs: EnsembleAccessor, source_fs: EnsembleReader
) -> Union[EnsembleReader, EnsembleAccessor]:
"""
Determines the source for a parameter group based on whether it is available in `target_fs`.
It is possible to update the same parameter multiple times.
For example, two update steps may be defined where one udpates the parameter using observation
`A` while the other udpates it using observation `B`.
After the processing of the first update step has completed, the updated parameter is stored in `traget_fs`.
Hence, when processing the second update step, we need to load the parameter from `target_fs` and not `source_fs`.
"""
if target_fs.has_parameter_group(param_group_name):
return target_fs
else:
return source_fs


def _calculate_adaptive_batch_size(num_params: int, num_obs: int) -> int:
"""Calculate adaptive batch size to optimize memory usage during Adaptive Localization
Adaptive Localization calculates the cross-covariance between parameters and responses.
Cross-covariance is a matrix with shape num_params x num_obs which may be larger than memory.
Therefore, a batching algorithm is used where only a subset of parameters is used when
calculating cross-covariance.
This function calculates a batch size that can fit into the available memory, accounting
for a safety margin.
The available memory is checked using the `psutil` library, which provides information about
system memory usage.
From `psutil` documentation:
- available:
the memory that can be given instantly to processes without the
system going into swap.
This is calculated by summing different memory values depending
on the platform and it is supposed to be used to monitor actual
memory usage in a cross platform fashion.
"""
available_memory_bytes = psutil.virtual_memory().available
memory_safety_factor = 0.8
bytes_in_float64 = 8
return min(
int(
np.floor(
available_memory_bytes
* memory_safety_factor
/ (num_obs * bytes_in_float64)
)
),
num_params,
)


def analysis_ES(
updatestep: UpdateConfiguration,
rng: np.random.Generator,
Expand All @@ -506,11 +574,14 @@ def analysis_ES(
misfit_process: bool,
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

ensemble_size = ens_mask.sum()
updated_parameter_groups = []

for update_step in updatestep:
updated_parameter_groups.extend(
[param_group.name for param_group in update_step.parameters]
)

progress_callback(
AnalysisStatusEvent(msg="Loading observations and responses..")
)
Expand All @@ -533,6 +604,7 @@ def analysis_ES(
)
except IndexError as e:
raise ErtAnalysisError(e) from e

smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot

num_obs = len(observation_values)
Expand All @@ -541,25 +613,50 @@ def analysis_ES(
f"No active observations for update step: {update_step.name}."
)

inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"}
try:
inversion_type = inversion_types[module.ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for: "
f"Specified: {module.ies_inversion}, with possible: {inversion_types}"
) from e

smoother_es = ies.ESMDA(
covariance=observation_errors**2,
observations=observation_values,
alpha=1, # The user is responsible for scaling observation covariance (esmda usage)
seed=rng,
inversion=inversion_type,
inversion_type = _determine_inversion_type(
module.ies_inversion, update_step.name
)

truncation = module.enkf_truncation

if module.localization:
# If doing global update, i.e., udpating all parameters using all observations.
if not module.localization:
smoother_es = ies.ESMDA(
covariance=observation_errors**2,
observations=observation_values,
alpha=1, # The user is responsible for scaling observation covariance (esmda usage)
seed=rng,
inversion=inversion_type,
)
# Compute transition matrix so that
# X_posterior = X_prior @ (I + T)
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

# One parameter group is updated at a time to save memory.
# We call this a "Streaming Algorithm".
for param_group in update_step.parameters:
source = _determine_parameter_source(
param_group.name, target_fs, source_fs
)
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)

# Update manually using global transition matrix T
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] @= T
else:
temp_storage[param_group.name] @= T

progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
else: # Adaptive Localization
smoother_adaptive_es = AdaptiveESMDA(
covariance=observation_errors**2,
observations=observation_values,
Expand All @@ -573,54 +670,16 @@ def analysis_ES(
ensemble_size=ensemble_size, alpha=1.0
)

else:
# Compute transition matrix so that
# X_posterior = X_prior @ T
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
if module.localization:
for param_group in update_step.parameters:
source = _determine_parameter_source(
param_group.name, target_fs, source_fs
)
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
num_params = temp_storage[param_group.name].shape[0]

# Calculate adaptive batch size.
# Adaptive Localization calculates the cross-covariance between
# parameters and responses.
# Cross-covariance is a matrix with shape num_params x num_obs
# which may be larger than memory.

# From `psutil` documentation:
# - available:
# the memory that can be given instantly to processes without the
# system going into swap.
# This is calculated by summing different memory values depending
# on the platform and it is supposed to be used to monitor actual
# memory usage in a cross platform fashion.
available_memory_bytes = psutil.virtual_memory().available
memory_safety_factor = 0.8
bytes_in_float64 = 8
batch_size = min(
int(
np.floor(
available_memory_bytes
* memory_safety_factor
/ (num_obs * bytes_in_float64)
)
),
num_params,
)
batch_size = _calculate_adaptive_batch_size(num_params, num_obs)

batches = _split_by_batchsize(np.arange(0, num_params), batch_size)

Expand All @@ -646,27 +705,10 @@ def analysis_ES(
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
)

else:
# Use low-level ies API to allow looping over parameters
if active_indices := param_group.index_list:
# The batch of parameters
X_local = temp_storage[param_group.name][active_indices, :]

# Update manually using global transition matrix T
temp_storage[param_group.name][active_indices, :] = X_local @ T

else:
# Update manually using global transition matrix T
temp_storage[param_group.name] @= T

log_msg = f"Storing data for {param_group.name}.."
_logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))
start = time.time()
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
_logger.info(
f"Storing data for {param_group.name} completed in {(time.time() - start) / 60} minutes"
)
progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)

# Finally, if some parameter groups have not been updated we need to copy the parameters
# from the parent ensemble.
Expand Down Expand Up @@ -808,11 +850,7 @@ def analysis_IES(

for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor] = target_fs
try:
target_fs.load_parameters(group=param_group.name, realizations=0)
except Exception:
source = source_fs
source = _determine_parameter_source(param_group.name, target_fs)

Check failure on line 853 in src/ert/analysis/_es_update.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

Missing positional argument "source_fs" in call to "_determine_parameter_source"
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
Expand Down

0 comments on commit c765c7f

Please sign in to comment.