Skip to content

Commit

Permalink
add graph information to ert parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed Mar 19, 2024
1 parent a193809 commit 06bf432
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def _create_temporary_parameter_storage(
print(f"prior_xdata: {prior_xdata.coords}")
print(f"prior_xdata: {prior_xdata.dims}")
print(f"prior_xdata values: {values_shape}")
graph = config_node.load_parameter_graph(ensemble, param_group, iens_active_index)
print(f"graph nodes: {len(list(graph.nodes))}")
print(f"graph edges: {len(list(graph.edges))}")
####

temp_storage[param_group] = config_node.load_parameters(
Expand Down Expand Up @@ -535,6 +538,8 @@ def adaptive_localization_progress_callback(
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

# Load all parameters _and_ graphs at the same time

for param_group in parameters:
source = source_ensemble
temp_storage = _create_temporary_parameter_storage(
Expand Down
8 changes: 8 additions & 0 deletions src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Mapping, MutableMapping, Tuple, Union

import networkx as nx
import numpy as np
import xarray as xr

from ert.storage import Ensemble

from .parameter_config import ParameterConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -106,6 +109,11 @@ def load_parameters(
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
raise NotImplementedError()

def load_parameter_graph(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> nx.Graph:
raise NotImplementedError()

@staticmethod
def to_dataset(data: DataType) -> xr.Dataset:
"""Flattens data to fit inside a dataset"""
Expand Down
82 changes: 82 additions & 0 deletions src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union, overload

import networkx as nx
import numpy as np
import xarray as xr
from typing_extensions import Self
Expand All @@ -27,6 +28,77 @@
_logger = logging.getLogger(__name__)


def create_flattened_cube_graph(px: int, py: int, pz: int) -> nx.Graph:
"""graph created with nodes numbered from 0 to px*py*pz
corresponds to the "vectorization" or flattening of
a 3D cube with shape (px,py,pz) in the same way as
reshaping such a cube into a one-dimensional array.
The indexing scheme used to create the graph reflects
this flattening process"""

G = nx.Graph()
for x in range(px):
for y in range(py):
for z in range(pz):
# Flatten the 3D index to a single index
index = x * py * pz + y * pz + z

# Connect to the right neighbor (y-direction)
if y < py - 1:
G.add_edge(index, index + pz)

# Connect to the bottom neighbor (x-direction)
if x < px - 1:
G.add_edge(index, index + py * pz)

# Connect to the neighbor in front (z-direction)
if z < pz - 1:
G.add_edge(index, index + 1)

return G


def adjust_graph_for_masking(G: nx.Graph, mask_indices: npt.NDArray[np.int_]):
"""
Adjust the graph G according to the masking indices.
For each masked index, its neighbors become neighbors of each other,
then the masked index is removed from the graph. After each removal,
nodes with an index greater than the removed node are decremented by 1.
Parameters:
- G: The graph to adjust
- mask_indices: Indices to mask, assumed to be sorted in ascending order
Returns:
- The adjusted graph
"""
removed_count = 0

for i in mask_indices:
# Adjust i for the number of removals to get the current index in the graph
current_index = i - removed_count
neighbors = list(G.neighbors(current_index))

# Make neighbors of the current node neighbors of each other
for u in neighbors:
for v in neighbors:
if u != v and not G.has_edge(u, v):
G.add_edge(u, v)

# Remove the current node
G.remove_node(current_index)

# Decrement indices of nodes greater than the current node
mapping = {
node: (node - 1 if node > current_index else node) for node in G.nodes()
}
nx.relabel_nodes(G, mapping, copy=False)

removed_count += 1

return G


@dataclass
class Field(ParameterConfig):
nx: int
Expand Down Expand Up @@ -209,6 +281,16 @@ def load_parameters(
)
return da.T.to_numpy()

def load_parameter_graph(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> nx.Graph: # type: ignore
parameter_graph = create_flattened_cube_graph(
px=self.nx, py=self.ny, pz=self.nz
)
return adjust_graph_for_masking(
G=parameter_graph, mask_indices=np.where(self.mask)[0]
)

def _fetch_from_ensemble(self, real_nr: int, ensemble: Ensemble) -> xr.DataArray:
da = ensemble.load_parameters(self.name, real_nr)["values"]
assert isinstance(da, xr.DataArray)
Expand Down
11 changes: 11 additions & 0 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
overload,
)

import networkx as nx
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -302,6 +303,16 @@ def load_parameters(
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
return ensemble.load_parameters(group, realizations)["values"].values.T

def load_parameter_graph(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> nx.Graph:
sample = ensemble.load_parameters(group, realizations)["values"].values
_, p = sample.shape
# Create a graph with no edges
graph_independence = nx.Graph()
graph_independence.add_nodes_from(range(p))
return graph_independence

def shouldUseLogScale(self, keyword: str) -> bool:
for tf in self.transfer_functions:
if tf.name == keyword:
Expand Down
10 changes: 10 additions & 0 deletions src/ert/config/parameter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -105,6 +106,15 @@ def load_parameters(
Load the parameter from internal storage for the given ensemble
"""

@abstractmethod
def load_parameter_graph(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> nx.Graph:
"""
Load the graph encoding Markov properties on the parameter `group`
Often a neighbourhood graph.
"""

def to_dict(self) -> Dict[str, Any]:
data = dataclasses.asdict(self, dict_factory=CustomDict)
data["_ert_kind"] = self.__class__.__name__
Expand Down
27 changes: 27 additions & 0 deletions src/ert/config/surface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Union

import networkx as nx
import numpy as np
import xarray as xr
import xtgeo
Expand All @@ -20,6 +21,27 @@
from ert.storage import Ensemble


def create_flattened_2d_graph(px: int, py: int) -> nx.Graph:
"""Graph created with nodes numbered from 0 to px*py
corresponds to the "vectorization" or flattening of
a 2D cube with shape (px,py) in the same way as
reshaping such a surface into a one-dimensional array.
The indexing scheme used to create the graph reflects
this flattening process"""

G = nx.Graph()
for i in range(px):
for j in range(py):
index = i * py + j # Flatten the 2D index to a single index
# Connect to the right neighbor
if j < py - 1:
G.add_edge(index, index + 1)
# Connect to the bottom neighbor
if i < px - 1:
G.add_edge(index, index + py)
return G


@dataclass
class SurfaceConfig(ParameterConfig):
ncol: int
Expand Down Expand Up @@ -162,3 +184,8 @@ def load_parameters(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> Union[npt.NDArray[np.float_], xr.DataArray]:
return ensemble.load_parameters(group, realizations)["values"]

def load_parameter_graph(
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
) -> nx.Graph: # type: ignore
return create_flattened_2d_graph(px=self.ncol, py=self.nrow)

0 comments on commit 06bf432

Please sign in to comment.