Skip to content

Commit

Permalink
efficiently remove nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Blunde1 committed Jul 23, 2024
1 parent 6996fba commit 7e8fc5c
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,45 +58,26 @@ def create_flattened_cube_graph(px: int, py: int, pz: int) -> nx.Graph:
return G


def adjust_graph_for_masking(
G: nx.Graph, mask: npt.NDArray[np.bool_], marginalize: bool = False
):
def adjust_graph_for_masking(G: nx.Graph, mask: npt.NDArray[np.bool_]):
"""
Adjust the graph G according to the masking indices.
For each masked index, its neighbors become neighbors of each other,
then the masked index is removed from the graph. After each removal,
nodes with an index greater than the removed node are decremented by 1.
Removes nodes specified by the mask and relabels the remaining nodes
to have consecutive labels from 0 to G.number_of_nodes - 1.
Parameters:
- G: The graph to adjust
- mask_indices: Indices to mask, assumed to be sorted in ascending order
- mask: Boolean mask flattened array
Returns:
- The adjusted graph
"""
print(f"Mask size: {mask.size} shape {mask.shape}")
print(f"mask itself: {mask}")
# Step 1: Remove nodes specified by mask_indices
mask_indices = np.where(mask)[0]
print(f"mask_indices: {mask_indices}")
for removed_count, i in enumerate(mask_indices):
# Adjust i for the number of removals to get the current index in the graph
print(f"removed_count: {removed_count}, i: {i}")
current_index = i - removed_count

if marginalize:
# Make neighbors of the current node neighbors of each other
neighbors = list(G.neighbors(current_index))
for u in neighbors:
for v in neighbors:
if u != v and not G.has_edge(u, v):
G.add_edge(u, v)

# Remove the current node
G.remove_node(current_index)

# Decrement indices of nodes greater than the current node
mapping = {
node: (node - 1 if node >= current_index else node) for node in G.nodes()
}
nx.relabel_nodes(G, mapping, copy=False)
G.remove_nodes_from(mask_indices)

# Step 2: Relabel remaining nodes to 0, 1, 2, ..., G.number_of_nodes - 1
new_labels = {old_label: new_label for new_label, old_label in enumerate(G.nodes())}
G = nx.relabel_nodes(G, new_labels, copy=False)

return G

Expand Down

0 comments on commit 7e8fc5c

Please sign in to comment.